<!DOCTYPE html><html lang="zh-CN" data-theme="light"><head><meta charset="UTF-8"><meta http-equiv="X-UA-Compatible" content="IE=edge"><meta name="viewport" content="width=device-width,initial-scale=1,maximum-scale=1,user-scalable=no"><title>深度学习 | 《深度学习入门之PyTorch》阅读笔记 | Justlovesmile's BLOG</title><meta name="keywords" content="深度学习,python,pytorch"><meta name="author" content="Justlovesmile,865717150@qq.com"><meta name="copyright" content="Justlovesmile"><meta name="format-detection" content="telephone=no"><meta name="theme-color" content="#ffffff"><meta name="description" content="深度学习入门之PyTorch第一章 深度学习介绍1.1 人工智能 Artificial Intelligence，人工智能，也称机器智能。 人工智能分为三大类（1）弱人工智能：擅长单方面（2）强人工智能：类似人类等级（3）超人工智能：全方面胜过人类  1.2 数据挖掘，机器学习和深度学习1.2.1 数据挖掘KDD（knowledge discovery in database），从数据中获取有意义"><meta property="og:type" content="article"><meta property="og:title" content="深度学习 | 《深度学习入门之PyTorch》阅读笔记"><meta property="og:url" content="https://blog.justlovesmile.top/posts/bfa4054.html"><meta property="og:site_name" content="Justlovesmile&#39;s BLOG"><meta property="og:description" content="深度学习入门之PyTorch第一章 深度学习介绍1.1 人工智能 Artificial Intelligence，人工智能，也称机器智能。 人工智能分为三大类（1）弱人工智能：擅长单方面（2）强人工智能：类似人类等级（3）超人工智能：全方面胜过人类  1.2 数据挖掘，机器学习和深度学习1.2.1 数据挖掘KDD（knowledge discovery in database），从数据中获取有意义"><meta property="og:locale" content="zh_CN"><meta property="og:image" content="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN2@latest/post/pytorch.jpg"><meta property="article:published_time" content="2020-10-23T10:43:24.000Z"><meta property="article:modified_time" content="2020-10-23T10:43:24.000Z"><meta property="article:author" content="Justlovesmile"><meta property="article:tag" content="深度学习"><meta property="article:tag" content="python"><meta property="article:tag" content="pytorch"><meta name="twitter:card" content="summary"><meta name="twitter:image" content="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN2@latest/post/pytorch.jpg"><link rel="shortcut icon" href="/img/logo.jpg"><link 
rel="canonical" href="https://blog.justlovesmile.top/posts/bfa4054"><link rel="preconnect" href="//cdn.jsdelivr.net"><link rel="preconnect" href="//hm.baidu.com"><link rel="stylesheet" href="/css/index.css"><link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/@fortawesome/fontawesome-free/css/all.min.css" media="print" onload='this.media="all"'><link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/node-snackbar/dist/snackbar.min.css" media="print" onload='this.media="all"'><link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/@fancyapps/ui/dist/fancybox.css" media="print" onload='this.media="all"'><script>var _hmt=_hmt||[];!function(){var e=document.createElement("script");e.src="https://hm.baidu.com/hm.js?a2ee893562999ebad688b0d82daa100a";var t=document.getElementsByTagName("script")[0];t.parentNode.insertBefore(e,t)}()</script><link rel="stylesheet" href="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN/font/family=Titillium+Web.css" media="print" onload='this.media="all"'><script>const GLOBAL_CONFIG={root:"/",algolia:void 0,localSearch:{path:"search.xml",languages:{hits_empty:"找不到您查询的内容：${query}"}},translate:void 0,noticeOutdate:void 0,highlight:{plugin:"highlighjs",highlightCopy:!0,highlightLang:!0,highlightHeightLimit:400},copy:{success:"复制成功",error:"复制错误",noSupport:"浏览器不支持"},relativeDate:{homepage:!1,post:!1},runtime:"天",date_suffix:{just:"刚刚",min:"分钟前",hour:"小时前",day:"天前",month:"个月前"},copyright:{limitCount:100,languages:{author:"作者: Justlovesmile",link:"链接: ",source:"来源: Justlovesmile's 
BLOG",info:"著作权归作者所有。商业转载请联系作者获得授权，非商业转载请注明出处。"}},lightbox:"fancybox",Snackbar:{chs_to_cht:"你已切换为繁体",cht_to_chs:"你已切换为简体",day_to_night:"你已切换为深色模式",night_to_day:"你已切换为浅色模式",bgLight:"var(--mj-card-bg)",bgDark:"var(--mj-card-bg)",position:"top-right"},source:{justifiedGallery:{js:"https://cdn.jsdelivr.net/npm/flickr-justified-gallery@2/dist/fjGallery.min.js",css:"https://cdn.jsdelivr.net/npm/flickr-justified-gallery@2/dist/fjGallery.min.css"}},isPhotoFigcaption:!1,islazyload:!0,isAnchor:!1}</script><script id="config-diff">var GLOBAL_CONFIG_SITE={title:"深度学习 | 《深度学习入门之PyTorch》阅读笔记",isPost:!0,isHome:!1,isHighlightShrink:!1,isToc:!0,postUpdate:"2020-10-23 18:43:24"}</script><noscript><style>#nav{opacity:1}.justified-gallery img{opacity:1}#post-meta time,#recent-posts time{display:inline!important}</style></noscript><script>(e=>{e.saveToLocal={set:function(e,t,o){if(0===o)return;const a=864e5*o,n={value:t,expiry:(new Date).getTime()+a};localStorage.setItem(e,JSON.stringify(n))},get:function(e){const t=localStorage.getItem(e);if(!t)return;const o=JSON.parse(t);if(!((new Date).getTime()>o.expiry))return o.value;localStorage.removeItem(e)}},e.getScript=e=>new Promise((t,o)=>{const a=document.createElement("script");a.src=e,a.async=!0,a.onerror=o,a.onload=a.onreadystatechange=function(){const e=this.readyState;e&&"loaded"!==e&&"complete"!==e||(a.onload=a.onreadystatechange=null,t())},document.head.appendChild(a)}),e.activateDarkMode=function(){document.documentElement.setAttribute("data-theme","dark"),null!==document.querySelector('meta[name="theme-color"]')&&document.querySelector('meta[name="theme-color"]').setAttribute("content","#0d0d0d")},e.activateLightMode=function(){document.documentElement.setAttribute("data-theme","light"),null!==document.querySelector('meta[name="theme-color"]')&&document.querySelector('meta[name="theme-color"]').setAttribute("content","#ffffff")};const t=saveToLocal.get("theme"),o=(new Date).getHours();void 
0===t?o<=6||o>=18?activateDarkMode():activateLightMode():"light"===t?activateLightMode():activateDarkMode();/iPad|iPhone|iPod|Macintosh/.test(navigator.userAgent)&&document.documentElement.classList.add("apple")})(window)</script><link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/font-awesome@4.7.0/css/font-awesome.min.css"><link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/font-awesome-animation@0.2.1/dist/font-awesome-animation.min.css"><link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/swiper/swiper-bundle.min.css"><link rel="stylesheet" href="/css/justlovesmile.css"><link rel="stylesheet" href="/css/blogicon.css"><meta name="generator" content="Hexo 5.4.0"><link rel="alternate" href="/atom.xml" title="Justlovesmile's BLOG" type="application/atom+xml"></head><body><div id="web_bg"></div><div id="sidebar"><div id="menu-mask"></div><div id="sidebar-menus"><div class="avatar-img is-center"><img src="" data-lazy-src="/img/avatar.jpg" onerror='onerror=null,src="/img/friend_404.gif"' alt="avatar"></div><div class="site-data is-center"><div class="data-item"><a href="/archives/"><div class="headline">文章</div><div class="length-num">75</div></a></div><div class="data-item"><a href="/tags/"><div class="headline">标签</div><div class="length-num">69</div></a></div><div class="data-item"><a href="/categories/"><div class="headline">分类</div><div class="length-num">6</div></a></div></div><hr><div class="menus_items"><div class="menus_item"><a class="site-page group" href="javascript:void(0);"><span>网站</span><i class="fas fa-chevron-down"></i></a><ul class="menus_item_child"><li><a class="site-page child" href="/"><i class="fa-fw fas fa-home"></i> <span>首页</span></a></li><li><a class="site-page child" href="/guestbook/"><i class="fa-fw fas fa-pencil-alt"></i> <span>留言</span></a></li><li><a class="site-page child" href="/friends/"><i class="fa-fw fas fa-paper-plane"></i> <span>友链</span></a></li></ul></div><div class="menus_item"><a class="site-page group" 
href="javascript:void(0);"><span>文库</span><i class="fas fa-chevron-down"></i></a><ul class="menus_item_child"><li><a class="site-page child" href="/tags/"><i class="fa-fw fas fa-tags"></i> <span>全部标签</span></a></li><li><a class="site-page child" href="/categories/"><i class="fa-fw fas fa-folder-open"></i> <span>全部分类</span></a></li><li><a class="site-page child" href="/archives/"><i class="fa-fw fas fa-calendar"></i> <span>文章列表</span></a></li><li><a class="site-page child" href="/random/"><i class="fa-fw fas fa-shoe-prints"></i> <span>随便逛逛</span></a></li></ul></div><div class="menus_item"><a class="site-page group" href="javascript:void(0);"><span>实验室</span><i class="fas fa-chevron-down"></i></a><ul class="menus_item_child"><li><a class="site-page child" href="/laboratory/"><i class="fa-fw fa fa-lightbulb-o"></i> <span>项目展示</span></a></li><li><a class="site-page child" href="/fcircle/"><i class="fa-fw fa fa-puzzle-piece"></i> <span>友链订阅</span></a></li><li><a class="site-page child" href="/charts/"><i class="fa-fw fa fa-pie-chart"></i> <span>博客统计</span></a></li><li><a class="site-page child" href="/update/"><i class="fa-fw fa fa-commenting-o"></i> <span>更新日志</span></a></li></ul></div><div class="menus_item"><a class="site-page group" href="javascript:void(0);"><span>清单</span><i class="fas fa-chevron-down"></i></a><ul class="menus_item_child"><li><a class="site-page child" href="/photos/"><i class="fa-fw fas fa-camera-retro"></i> <span>相册</span></a></li><li><a class="site-page child" href="/video/"><i class="fa-fw fa fa-video-camera"></i> <span>视频</span></a></li><li><a class="site-page child" href="/music/"><i class="fa-fw fas fa-music"></i> <span>歌单</span></a></li><li><a class="site-page child" href="/focus/"><i class="fa-fw fa fa-check-square-o"></i> <span>关注</span></a></li></ul></div><div class="menus_item"><a class="site-page group" href="javascript:void(0);"><span>关于</span><i class="fas fa-chevron-down"></i></a><ul class="menus_item_child"><li><a class="site-page 
child" href="/about/"><i class="fa-fw fas fa-user"></i> <span>本站</span></a></li><li><a class="site-page child" href="/love/"><i class="fa-fw fa fa-heart"></i> <span>Love</span></a></li><li><a class="site-page child" href="/donate/"><i class="fa-fw fa fa-gratipay"></i> <span>打赏</span></a></li></ul></div></div></div></div><div class="post" id="body-wrap"><header class="post-bg" id="page-header" style="background-image:url(https://cdn.jsdelivr.net/gh/Justlovesmile/CDN2@latest/post/pytorch.jpg)"><nav id="nav"><div id="nav-group"><div id="blog_name"><a id="site-name" href="/">Justlovesmile</a></div><div id="menus"><div class="menus_items"><div class="menus_item"><a class="site-page group" href="javascript:void(0);"><span>网站</span><i class="fas fa-chevron-down"></i></a><ul class="menus_item_child"><li><a class="site-page child" href="/"><i class="fa-fw fas fa-home"></i> <span>首页</span></a></li><li><a class="site-page child" href="/guestbook/"><i class="fa-fw fas fa-pencil-alt"></i> <span>留言</span></a></li><li><a class="site-page child" href="/friends/"><i class="fa-fw fas fa-paper-plane"></i> <span>友链</span></a></li></ul></div><div class="menus_item"><a class="site-page group" href="javascript:void(0);"><span>文库</span><i class="fas fa-chevron-down"></i></a><ul class="menus_item_child"><li><a class="site-page child" href="/tags/"><i class="fa-fw fas fa-tags"></i> <span>全部标签</span></a></li><li><a class="site-page child" href="/categories/"><i class="fa-fw fas fa-folder-open"></i> <span>全部分类</span></a></li><li><a class="site-page child" href="/archives/"><i class="fa-fw fas fa-calendar"></i> <span>文章列表</span></a></li><li><a class="site-page child" href="/random/"><i class="fa-fw fas fa-shoe-prints"></i> <span>随便逛逛</span></a></li></ul></div><div class="menus_item"><a class="site-page group" href="javascript:void(0);"><span>实验室</span><i class="fas fa-chevron-down"></i></a><ul class="menus_item_child"><li><a class="site-page child" href="/laboratory/"><i class="fa-fw fa 
fa-lightbulb-o"></i> <span>项目展示</span></a></li><li><a class="site-page child" href="/fcircle/"><i class="fa-fw fa fa-puzzle-piece"></i> <span>友链订阅</span></a></li><li><a class="site-page child" href="/charts/"><i class="fa-fw fa fa-pie-chart"></i> <span>博客统计</span></a></li><li><a class="site-page child" href="/update/"><i class="fa-fw fa fa-commenting-o"></i> <span>更新日志</span></a></li></ul></div><div class="menus_item"><a class="site-page group" href="javascript:void(0);"><span>清单</span><i class="fas fa-chevron-down"></i></a><ul class="menus_item_child"><li><a class="site-page child" href="/photos/"><i class="fa-fw fas fa-camera-retro"></i> <span>相册</span></a></li><li><a class="site-page child" href="/video/"><i class="fa-fw fa fa-video-camera"></i> <span>视频</span></a></li><li><a class="site-page child" href="/music/"><i class="fa-fw fas fa-music"></i> <span>歌单</span></a></li><li><a class="site-page child" href="/focus/"><i class="fa-fw fa fa-check-square-o"></i> <span>关注</span></a></li></ul></div><div class="menus_item"><a class="site-page group" href="javascript:void(0);"><span>关于</span><i class="fas fa-chevron-down"></i></a><ul class="menus_item_child"><li><a class="site-page child" href="/about/"><i class="fa-fw fas fa-user"></i> <span>本站</span></a></li><li><a class="site-page child" href="/love/"><i class="fa-fw fa fa-heart"></i> <span>Love</span></a></li><li><a class="site-page child" href="/donate/"><i class="fa-fw fa fa-gratipay"></i> <span>打赏</span></a></li></ul></div></div></div><div id="nav-right"><div id="search-button"><a class="nav-rightbutton site-page social-icon search"><i class="fas fa-search fa-fw"></i></a></div><div id="darkmode_navswitch"><a class="nav-rightbutton site-page darkmode_switchbutton" onclick="switchDarkMode()"><i class="fas fa-adjust"></i></a></div><div id="toggle-menu"><a class="nav-rightbutton site-page"><i class="fas fa-bars fa-fw"></i></a></div></div></div></nav><div class="coverdiv" id="coverdiv"><img class="cover entered 
loading" id="post-cover" alt="cover" src="" data-lazy-src="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN2@latest/post/pytorch.jpg"></div><div id="post-info"><div class="post-firstinfo" id="post-meta"><span class="post-meta-categories"><i class="fas fa-inbox fa-fw post-meta-icon"></i><a class="post-meta-categories" href="/categories/%E4%BA%BA%E5%B7%A5%E6%99%BA%E8%83%BD/">人工智能</a></span><div class="post-meta__tag-list"><a class="post-meta__tags" href="/tags/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/">#深度学习</a><a class="post-meta__tags" href="/tags/python/">#python</a><a class="post-meta__tags" href="/tags/pytorch/">#pytorch</a></div></div><h1 class="post-title">深度学习 | 《深度学习入门之PyTorch》阅读笔记</h1><div id="post-meta"><div class="meta-firstline"><span class="meta-share-time"><span class="meta-avatar"><a class="meta-avatar-img" href="/about/" title="关于作者"><img alt="作者头像" src="" data-lazy-src="/img/avatar.jpg"></a><a class="meta-avatar-name" href="/about/" title="关于作者">Justlovesmile</a></span></span><span class="post-meta-date"><i class="fa-fw post-meta-icon far fa-calendar-alt"></i><span class="post-meta-label">发表于</span><time datetime="2020-10-23T10:43:24.000Z" title="发表于 2020-10-23 18:43:24">2020-10-23</time></span></div><div class="meta-secondline"><span class="post-meta-separator">|</span><span class="post-meta-wordcount"><i class="far fa-file-word fa-fw post-meta-icon"></i><span class="post-meta-label">字数总计:</span><span class="word-count">26.5k</span><span class="post-meta-separator">|</span><i class="far fa-clock fa-fw post-meta-icon"></i><span class="post-meta-label">阅读时长:</span><span>115分钟</span></span></div></div></div></header><main class="layout" id="content-inner"><div id="post"><article class="post-content" id="article-container"><h1 id="深度学习入门之PyTorch"><a href="#深度学习入门之PyTorch" class="headerlink" title="深度学习入门之PyTorch"></a>深度学习入门之PyTorch</h1><h2 id="第一章-深度学习介绍"><a href="#第一章-深度学习介绍" class="headerlink" title="第一章 深度学习介绍"></a>第一章 深度学习介绍</h2><h3 id="1-1-人工智能"><a 
href="#1-1-人工智能" class="headerlink" title="1.1 人工智能"></a>1.1 人工智能</h3><ol><li>Artificial Intelligence，人工智能，也称机器智能。</li><li>人工智能分为三大类<br>（1）弱人工智能：擅长单方面<br>（2）强人工智能：类似人类等级<br>（3）超人工智能：全方面胜过人类</li></ol><h3 id="1-2-数据挖掘，机器学习和深度学习"><a href="#1-2-数据挖掘，机器学习和深度学习" class="headerlink" title="1.2 数据挖掘，机器学习和深度学习"></a>1.2 数据挖掘，机器学习和深度学习</h3><h4 id="1-2-1-数据挖掘"><a href="#1-2-1-数据挖掘" class="headerlink" title="1.2.1 数据挖掘"></a>1.2.1 数据挖掘</h4><p>KDD（knowledge discovery in database），从数据中获取有意义的信息</p><h4 id="1-2-2-机器学习"><a href="#1-2-2-机器学习" class="headerlink" title="1.2.2 机器学习"></a>1.2.2 机器学习</h4><ol><li>机器学习是实现人工智能的一种途径，涉及多门学科</li><li>大致分为五个大类<br>（1）监督学习：从给定的训练数据集中学习出一个函数，训练集中的目标是由人标注的，常见算法包括回归和分类<br>（2）无监督学习：训练集没有人为标注，常见算法如聚类<br>（3）半监督学习：介于两者之间<br>（4）迁移学习：将已经训练好的模型参数迁移到新的模型来帮助新模型训练数据集<br>（5）增强学习：通过观察周围环境来学习</li></ol><h4 id="1-2-3-深度学习"><a href="#1-2-3-深度学习" class="headerlink" title="1.2.3 深度学习"></a>1.2.3 深度学习</h4><ol><li>机器学习的一个分支，通过模拟人脑来实现数据特征的提取</li><li>常见网络结构：DNN，CNN，RNN，GAN等等</li></ol><h2 id="第二章-深度学习框架"><a href="#第二章-深度学习框架" class="headerlink" title="第二章 深度学习框架"></a>第二章 深度学习框架</h2><h3 id="2-1-深度学习框架介绍"><a href="#2-1-深度学习框架介绍" class="headerlink" title="2.1 深度学习框架介绍"></a>2.1 深度学习框架介绍</h3><ol><li>Tensorflow<br>Google开源的基于C++开发的数学计算软件</li><li>Caffe</li><li>Theano</li><li>Torch<br>支持动态图</li><li>MXNet</li></ol><h3 id="2-2-PyTorch介绍"><a href="#2-2-PyTorch介绍" class="headerlink" title="2.2 PyTorch介绍"></a>2.2 PyTorch介绍</h3><h4 id="2-2-1-什么是PyTorch"><a href="#2-2-1-什么是PyTorch" class="headerlink" title="2.2.1 什么是PyTorch"></a>2.2.1 什么是PyTorch</h4><p>Python优先的深度学习框架，支持GPU加速和动态神经网络</p><h4 id="2-2-2-为什么使用PyTorch"><a href="#2-2-2-为什么使用PyTorch" class="headerlink" title="2.2.2 为什么使用PyTorch"></a>2.2.2 为什么使用PyTorch</h4><p>1.多学习一个框架准没错<br>2.PyTorch通过一种反向自动求导的技术，可以让你零延迟地改变神经网络<br>3.线性，直观，易于使用<br>4.代码简洁直观，底层代码友好</p><h3 id="2-3-配置PyTorch深度学习环境"><a href="#2-3-配置PyTorch深度学习环境" class="headerlink" title="2.3 配置PyTorch深度学习环境"></a>2.3 配置PyTorch深度学习环境</h3><h4 id="2-3-1-操作系统"><a href="#2-3-1-操作系统" 
class="headerlink" title="2.3.1 操作系统"></a>2.3.1 操作系统</h4><p>Windows，Linux，Mac</p><h4 id="2-3-2-Python开发环境的安装"><a href="#2-3-2-Python开发环境的安装" class="headerlink" title="2.3.2 Python开发环境的安装"></a>2.3.2 Python开发环境的安装</h4><p>Anaconda</p><h4 id="2-3-3-PyTorch安装"><a href="#2-3-3-PyTorch安装" class="headerlink" title="2.3.3 PyTorch安装"></a>2.3.3 PyTorch安装</h4><p>官网或者anaconda</p><p>CPU或GPU</p><p>CUDA，CuDnn</p><h2 id="第三章-多层全连接神经网络"><a href="#第三章-多层全连接神经网络" class="headerlink" title="第三章 多层全连接神经网络"></a>第三章 多层全连接神经网络</h2><h3 id="3-1-PyTorch基础"><a href="#3-1-PyTorch基础" class="headerlink" title="3.1 PyTorch基础"></a>3.1 PyTorch基础</h3><h4 id="3-1-1-Tensor张量"><a href="#3-1-1-Tensor张量" class="headerlink" title="3.1.1 Tensor张量"></a>3.1.1 Tensor张量</h4><p>Tensor相当于多维的矩阵</p><p>Tensor的数据类型有：(32位浮点型)torch.FloatTensor，（64位浮点型）torch.DoubleTensor，（16位整型）torch.ShortTensor,（32位整型）torch.IntTensor，（64位整型）torch.LongTensor</p><p><strong>导入pytorch</strong></p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> __future__ <span class="keyword">import</span> print_function</span><br><span class="line"><span class="keyword">import</span> torch</span><br></pre></td></tr></table></figure><p><strong>创建一个没有初始化的5×3矩阵</strong></p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">x=torch.empty(<span class="number">5</span>,<span class="number">3</span>)</span><br><span class="line"><span class="built_in">print</span>(x)</span><br></pre></td></tr></table></figure><p><strong>创建一个随机初始化矩阵</strong></p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span 
class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">#均匀分布[0,1],rand</span></span><br><span class="line">x=torch.rand(<span class="number">5</span>,<span class="number">3</span>)</span><br><span class="line"><span class="built_in">print</span>(x)</span><br><span class="line"></span><br><span class="line"><span class="comment">#正态分布，randn</span></span><br><span class="line">x=torch.randn(<span class="number">5</span>,<span class="number">3</span>)</span><br><span class="line"><span class="built_in">print</span>(x)</span><br></pre></td></tr></table></figure><p><strong>构造一个0矩阵，且数据类型为long</strong></p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">x=torch.zeros(<span class="number">5</span>,<span class="number">3</span>,dtype=torch.long)</span><br><span class="line"><span class="built_in">print</span>(x)</span><br></pre></td></tr></table></figure><p><strong>直接根据数据构造张量</strong></p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">x=torch.tensor([<span class="number">5.5</span>,<span class="number">3</span>])</span><br><span class="line"><span class="built_in">print</span>(x)</span><br></pre></td></tr></table></figure><p><strong>创建一个全为1的矩阵，且数据类型为double</strong></p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line">x=torch.ones(<span class="number">5</span>,<span class="number">3</span>)</span><br><span class="line"><span class="built_in">print</span>(x)</span><br><span class="line"></span><br><span 
class="line">x=x.new_ones(<span class="number">5</span>,<span class="number">3</span>,dtype=torch.double)</span><br><span class="line"><span class="built_in">print</span>(x)</span><br></pre></td></tr></table></figure><p><strong>根据已有tensor建立新的tensor，且除非提供新值，将重用所给张量属性</strong></p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line">x=x.new_ones(<span class="number">5</span>,<span class="number">3</span>,dtype=torch.double)</span><br><span class="line"><span class="built_in">print</span>(x)</span><br><span class="line"></span><br><span class="line">x=torch.randn_like(x,dtype=torch.<span class="built_in">float</span>)</span><br><span class="line"><span class="built_in">print</span>(x)</span><br></pre></td></tr></table></figure><p><strong>获取张量的形状</strong></p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line"><span class="built_in">print</span>(x.size())</span><br></pre></td></tr></table></figure><blockquote><p><strong>注意</strong><br><code>torch.Size</code>本质上还是tuple，所以支持tuple的一切操作</p></blockquote><p><strong>和numpy的相互转换</strong></p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="built_in">print</span>(x)</span><br><span class="line">numpy_x = x.numpy()</span><br><span class="line"><span class="built_in">print</span>(numpy_x)</span><br><span class="line">torch_x = torch.from_numpy(numpy_x)</span><br><span class="line"><span 
class="built_in">print</span>(torch_x)</span><br></pre></td></tr></table></figure><p><strong>绝对值</strong></p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line">a=torch.randn(<span class="number">2</span>,<span class="number">3</span>)</span><br><span class="line"><span class="built_in">print</span>(a)</span><br><span class="line"></span><br><span class="line">b=torch.<span class="built_in">abs</span>(a)</span><br><span class="line"><span class="built_in">print</span>(b)</span><br></pre></td></tr></table></figure><p><strong>运算</strong>，例如加法</p><p><strong>形式一</strong></p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">y=torch.rand(<span class="number">5</span>,<span class="number">3</span>)</span><br><span class="line"><span class="built_in">print</span>(x+y)</span><br></pre></td></tr></table></figure><p><strong>形式二</strong></p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line"><span class="built_in">print</span>(torch.add(x,y))</span><br></pre></td></tr></table></figure><p><strong>形式三</strong></p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">result=torch.empty(<span class="number">5</span>,<span class="number">3</span>)</span><br><span class="line">torch.add(x,y,out=result)</span><br><span class="line"><span class="built_in">print</span>(result)</span><br></pre></td></tr></table></figure><p><strong>形式四</strong></p><figure class="highlight 
python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">y.add_(x)</span><br><span class="line"><span class="built_in">print</span>(y)</span><br></pre></td></tr></table></figure><blockquote><p><strong>注意：</strong><br>任何一个in-place改变张量的操作后面都固定一个_。例如x.copy_(y)、x.t_()将更改x</p></blockquote><p><strong>剪裁</strong>:如果在上下边界内则不变，否则大于上边界值，则改为上边界值，小于下边界值，则改为下边界值</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line">a=torch.randn(<span class="number">2</span>,<span class="number">3</span>)</span><br><span class="line"><span class="built_in">print</span>(a)</span><br><span class="line"></span><br><span class="line">b=torch.clamp(a,-<span class="number">0.1</span>,<span class="number">0.1</span>)</span><br><span class="line"><span class="built_in">print</span>(b)</span><br></pre></td></tr></table></figure><p><strong>除法</strong></p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line">a=torch.randn(<span class="number">2</span>,<span class="number">3</span>)</span><br><span class="line">b=torch.randn(<span class="number">2</span>,<span class="number">3</span>)</span><br><span class="line">c=torch.div(a,b)</span><br><span class="line">d=torch.div(c,<span class="number">10</span>)</span><br><span class="line"><span class="built_in">print</span>(a)</span><br><span class="line"><span class="built_in">print</span>(b)</span><br><span class="line"><span 
class="built_in">print</span>(c)</span><br><span class="line"><span class="built_in">print</span>(d)</span><br></pre></td></tr></table></figure><blockquote><p><strong>加法</strong>add，<strong>乘积</strong>mul，<strong>除法</strong>div，<strong>求幂</strong>pow，<strong>矩阵乘法</strong>mm，<strong>矩阵向量乘法</strong>mv</p></blockquote><p><strong>改变张量形状</strong></p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">x=torch.randn(<span class="number">4</span>,<span class="number">4</span>)</span><br><span class="line">y=x.view(<span class="number">16</span>)</span><br><span class="line">z=x.view(-<span class="number">1</span>,<span class="number">8</span>) <span class="comment"># -1将会自动取值</span></span><br><span class="line"><span class="built_in">print</span>(x.size(),y.size(),z.size())</span><br></pre></td></tr></table></figure><p><strong>对于只含一个元素的tensor，可以使用<code>.item()</code>来得到数值</strong></p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">x=torch.randn(<span class="number">1</span>)</span><br><span class="line"><span class="built_in">print</span>(x)</span><br><span class="line"><span class="built_in">print</span>(x.item())</span><br></pre></td></tr></table></figure><p><strong>使用GPU</strong></p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">if</span> torch.cuda.is_available():</span><br><span class="line">    device = 
torch.device(<span class="string">&quot;cuda&quot;</span>)</span><br><span class="line">    y = torch.ones_like(x, device=device)</span><br><span class="line">    x = x.to(device)</span><br><span class="line">    z = x+y</span><br><span class="line">    <span class="built_in">print</span>(z)</span><br><span class="line">    <span class="built_in">print</span>(z.to(<span class="string">&quot;cpu&quot;</span>,torch.double))</span><br></pre></td></tr></table></figure><h4 id="3-1-2-Variable（变量）"><a href="#3-1-2-Variable（变量）" class="headerlink" title="3.1.2 Variable（变量）"></a>3.1.2 Variable（变量）</h4><p><strong>1. Autograd：自动求导</strong></p><p><strong>创建一个张量并设置requires_grad=True用来追踪其计算历史</strong></p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">x=torch.ones(<span class="number">2</span>,<span class="number">2</span>,requires_grad=<span class="literal">True</span>)</span><br><span class="line"><span class="built_in">print</span>(x)</span><br></pre></td></tr></table></figure><p><strong>对这个张量做一次运算</strong></p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line">y=x+<span class="number">2</span></span><br><span class="line"><span class="built_in">print</span>(y)</span><br><span class="line"><span class="comment"># y是计算结果，所以他有grad_fn属性</span></span><br><span class="line"><span class="built_in">print</span>(y.grad_fn)</span><br><span class="line"><span class="comment"># 对y进行更多操作</span></span><br><span class="line">z=y*y*<span class="number">3</span></span><br><span class="line">out=z.mean()</span><br><span class="line"><span 
class="built_in">print</span>(z,out)</span><br></pre></td></tr></table></figure><p>.requires_grad_(…) 改变了现有张量的 requires_grad 标志。如果没有指定的话，默认输入的这个标志是 False。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line">a = torch.randn(<span class="number">2</span>, <span class="number">2</span>)</span><br><span class="line">a = ((a * <span class="number">3</span>) / (a - <span class="number">1</span>))</span><br><span class="line"><span class="built_in">print</span>(a.requires_grad)</span><br><span class="line">a.requires_grad_(<span class="literal">True</span>)</span><br><span class="line"><span class="built_in">print</span>(a.requires_grad)</span><br><span class="line">b = (a * a).<span class="built_in">sum</span>()</span><br><span class="line"><span class="built_in">print</span>(b.grad_fn)</span><br></pre></td></tr></table></figure><p><strong>2. 
梯度</strong></p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line">x=torch.ones(<span class="number">2</span>,<span class="number">2</span>,requires_grad=<span class="literal">True</span>)</span><br><span class="line">y=x+<span class="number">2</span></span><br><span class="line">z=y*y*<span class="number">3</span></span><br><span class="line">out=z.mean()</span><br><span class="line"><span class="comment"># 现在开始反向传播，因为out是一个标量，则out.backward()和out.backward(torch.tensor(1.))等价</span></span><br><span class="line">out.backward()</span><br><span class="line"><span class="comment">#输出导数d(out)/dx</span></span><br><span class="line"><span class="built_in">print</span>(x.grad)</span><br></pre></td></tr></table></figure><p>即</p><p>$$out=\frac{1}{4}\sum_iz_i$$</p><p>$$z_i=3(x_i+2)^2$$</p><p>并且$z_i|_{x_i=1}=27$，因此，有</p><p>$$\frac{\partial out}{\partial x_i}=\frac{3}{2}(x_i+2)$$<br>因此<br>$$\frac{\partial out}{\partial x_i}\Big|_{x_i=1}=\frac{9}{2}=4.5$$</p><p><strong>雅可比矩阵</strong></p><p>数学上，若有向量值函数y=f(x)，那么y相当于对x的梯度是一个雅可比矩阵（下面是一个latex数学公式）</p><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line">J=\begin&#123;bmatrix&#125;</span><br><span class="line">\frac&#123;\partial y_1&#125;&#123;\partial x_1&#125; &amp;\cdots&amp; \frac&#123;\partial y_1&#125;&#123;\partial x_n&#125; \\</span><br><span class="line">\vdots &amp; \ddots &amp; \vdots \\</span><br><span class="line">\frac&#123;\partial y_m&#125;&#123;\partial x_1&#125; 
&amp;\cdots&amp; \frac&#123;\partial y_m&#125;&#123;\partial x_n&#125;</span><br><span class="line">\end&#123;bmatrix&#125;</span><br></pre></td></tr></table></figure><p><img src="" data-lazy-src="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN2/post/image-20211125183936512.png" alt="image-20211125183936512"></p><p>通常来说，torch.autograd是计算雅可比向量积的一个引擎。也就是说，给定任意向量v，计算乘积$v^TJ$.如果v恰好是一个标量函数l=g(y)的导数，即$v=(\frac{\partial l}{\partial y_1} \cdots \frac{\partial l}{\partial y_m})^T$，那么根据链式法则，雅可比向量积应该是l对x的导数</p><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br></pre></td><td class="code"><pre><span class="line">J^T·v=\begin&#123;bmatrix&#125;</span><br><span class="line">\frac&#123;\partial y_1&#125;&#123;\partial x_1&#125; &amp;\cdots&amp; \frac&#123;\partial y_m&#125;&#123;\partial x_1&#125; \\</span><br><span class="line">\vdots &amp; \ddots &amp; \vdots \\</span><br><span class="line">\frac&#123;\partial y_1&#125;&#123;\partial x_n&#125; &amp;\cdots&amp; \frac&#123;\partial y_m&#125;&#123;\partial x_n&#125;</span><br><span class="line">\end&#123;bmatrix&#125;</span><br><span class="line">\begin&#123;bmatrix&#125;</span><br><span class="line">\frac&#123;\partial l&#125;&#123;\partial y_1&#125;\\</span><br><span class="line">\cdots\\</span><br><span class="line">\frac&#123;\partial l&#125;&#123;\partial y_m&#125;</span><br><span class="line">\end&#123;bmatrix&#125;=</span><br><span class="line">\begin&#123;bmatrix&#125;</span><br><span class="line">\frac&#123;\partial l&#125;&#123;\partial 
x_1&#125;\\</span><br><span class="line">\cdots\\</span><br><span class="line">\frac&#123;\partial l&#125;&#123;\partial x_n&#125;</span><br><span class="line">\end&#123;bmatrix&#125;</span><br></pre></td></tr></table></figure><p><img src="" data-lazy-src="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN2/post/image-20211125184143947.png" alt="image-20211125184143947"></p><p>(注意：行向量的$v^T⋅J$也可以被视作列向量的$J^T⋅v$)</p><p>雅可比向量积的这一特性使得将外部梯度输入到具有非标量输出的模型中变得非常方便。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line">x=torch.randn(<span class="number">3</span>,requires_grad=<span class="literal">True</span>)</span><br><span class="line">y=x*<span class="number">2</span></span><br><span class="line"><span class="keyword">while</span> y.data.norm() &lt;<span class="number">1000</span>:</span><br><span class="line">    y=y*<span class="number">2</span></span><br><span class="line"></span><br><span class="line"><span class="built_in">print</span>(y)</span><br></pre></td></tr></table></figure><p>在这种情况下，y 不再是标量。torch.autograd 不能直接计算完整的雅可比矩阵，但是如果我们只想要雅可比向量积，只需将这个向量作为参数传给 backward</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">v = torch.tensor([<span class="number">0.1</span>, <span class="number">1.0</span>, <span class="number">0.0001</span>], dtype=torch.<span class="built_in">float</span>)</span><br><span class="line">y.backward(v)</span><br><span class="line"></span><br><span class="line"><span class="built_in">print</span>(x.grad)</span><br></pre></td></tr></table></figure><p>也可以通过将代码块包装在 with torch.no_grad(): 中，来阻止autograd跟踪设置了 
.requires_grad=True 的张量的历史记录。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="built_in">print</span>(x.requires_grad)</span><br><span class="line"><span class="built_in">print</span>((x ** <span class="number">2</span>).requires_grad)</span><br><span class="line"></span><br><span class="line"><span class="keyword">with</span> torch.no_grad():</span><br><span class="line">    <span class="built_in">print</span>((x ** <span class="number">2</span>).requires_grad)</span><br></pre></td></tr></table></figure><p><strong>3. Variable</strong></p><p>Variable和Tensor的区别，Variable会被放入计算图中，然后进行前向传播，反向传播，自动求导</p><p>Variable是在torch.autograd.Variable中</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> torch.autograd <span class="keyword">import</span> Variable</span><br><span class="line"></span><br><span class="line">x=Variable(torch.Tensor([<span class="number">1</span>]),requires_grad=<span class="literal">True</span>)</span><br><span class="line">w=Variable(torch.Tensor([<span class="number">2</span>]),requires_grad=<span class="literal">True</span>)</span><br><span class="line">b=Variable(torch.Tensor([<span class="number">3</span>]),requires_grad=<span class="literal">True</span>)</span><br><span class="line"></span><br><span 
class="line">y=w*x+b</span><br><span class="line"></span><br><span class="line">y.backward()</span><br><span class="line"></span><br><span class="line"><span class="built_in">print</span>(x.grad)</span><br><span class="line"><span class="built_in">print</span>(w.grad)</span><br><span class="line"><span class="built_in">print</span>(b.grad)</span><br></pre></td></tr></table></figure><p><strong>搭建一个简单的神经网络</strong></p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br></pre></td><td class="code"><pre><span class="line">batch_n = <span class="number">100</span> <span class="comment"># 一个批次中输入数据的数量</span></span><br><span class="line">hidden_layer = <span class="number">100</span> <span class="comment"># 经过隐藏层后保留的数据特征的个数</span></span><br><span class="line">input_data = <span class="number">1000</span> <span class="comment"># 每个数据包含的数据量</span></span><br><span class="line">output_data = <span class="number">10</span> <span class="comment"># 每个输出的数据包含的数据量</span></span><br><span 
class="line"></span><br><span class="line">x=torch.randn(batch_n,input_data) <span class="comment">#100*1000</span></span><br><span class="line">y=torch.randn(batch_n,output_data) <span class="comment">#100*10</span></span><br><span class="line"></span><br><span class="line">w1=torch.randn(input_data,hidden_layer) <span class="comment">#1000*100</span></span><br><span class="line">w2=torch.randn(hidden_layer,output_data) <span class="comment"># 100*10</span></span><br><span class="line"></span><br><span class="line">epoch_n = <span class="number">20</span> <span class="comment">#训练的次数</span></span><br><span class="line">learning_rate = <span class="number">1e-6</span> <span class="comment">#学习率</span></span><br><span class="line"></span><br><span class="line"><span class="keyword">for</span> epoch <span class="keyword">in</span> <span class="built_in">range</span>(epoch_n):</span><br><span class="line">    h1=x.mm(w1)<span class="comment">#100*100</span></span><br><span class="line">    h1=h1.clamp(<span class="built_in">min</span>=<span class="number">0</span>) <span class="comment"># if x&lt;0 ,x=0</span></span><br><span class="line">    y_pred=h1.mm(w2) <span class="comment">#100*10，前向传播预测结果</span></span><br><span class="line">    </span><br><span class="line">    loss = (y_pred - y).<span class="built_in">pow</span>(<span class="number">2</span>).<span class="built_in">sum</span>() <span class="comment">#损失函数，即均方误差</span></span><br><span class="line">    <span class="built_in">print</span>(<span class="string">&quot;Epoch:&#123;&#125;, Loss:&#123;:.4f&#125;&quot;</span>.<span class="built_in">format</span>(epoch,loss))</span><br><span class="line">    grad_y_pred = <span class="number">2</span>*(y_pred-y) <span class="comment">#dloss/dy</span></span><br><span class="line">    grad_w2 = h1.t().mm(grad_y_pred) <span class="comment">#dloss/dy * dy/dw2</span></span><br><span class="line">    </span><br><span class="line">    grad_h = grad_y_pred.clone() <span 
class="comment">#复制</span></span><br><span class="line">    grad_h = grad_h.mm(w2.t()) <span class="comment">#dloss/dy * dy/dh1</span></span><br><span class="line">    grad_h.clamp_(<span class="built_in">min</span>=<span class="number">0</span>) <span class="comment"># if x&lt;0 ,x=0</span></span><br><span class="line">    grad_w1 = x.t().mm(grad_h) </span><br><span class="line">    </span><br><span class="line">    w1 -= learning_rate*grad_w1 <span class="comment">#梯度下降</span></span><br><span class="line">    w2 -= learning_rate*grad_w2</span><br></pre></td></tr></table></figure><p><strong>使用Variable搭建一个自动计算梯度的神经网络</strong></p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> torch.autograd <span class="keyword">import</span> Variable</span><br><span class="line"></span><br><span class="line">batch_n = <span class="number">100</span> <span class="comment"># 一个批次中输入数据的数量</span></span><br><span class="line">hidden_layer = <span class="number">100</span> <span class="comment"># 经过隐藏层后保留的数据特征的个数</span></span><br><span 
class="line">input_data = <span class="number">1000</span> <span class="comment"># 每个数据包含的数据量</span></span><br><span class="line">output_data = <span class="number">10</span> <span class="comment"># 每个输出的数据包含的数据量</span></span><br><span class="line"></span><br><span class="line">x=Variable(torch.randn(batch_n,input_data),requires_grad = <span class="literal">False</span>) <span class="comment">#requires_grad = False不保留梯度</span></span><br><span class="line">y=Variable(torch.randn(batch_n,output_data),requires_grad = <span class="literal">False</span>)</span><br><span class="line">w1=Variable(torch.randn(input_data,hidden_layer),requires_grad = <span class="literal">True</span>) <span class="comment">#requires_grad = True自动保留梯度</span></span><br><span class="line">w2=Variable(torch.randn(hidden_layer,output_data),requires_grad = <span class="literal">True</span>)</span><br><span class="line"></span><br><span class="line">epoch_n = <span class="number">20</span></span><br><span class="line">learning_rate = <span class="number">1e-6</span></span><br><span class="line"></span><br><span class="line"><span class="keyword">for</span> epoch <span class="keyword">in</span> <span class="built_in">range</span>(epoch_n):</span><br><span class="line">    y_pred = x.mm(w1).clamp(<span class="built_in">min</span> = <span class="number">0</span>).mm(w2) <span class="comment">#y_pred=w2*(w1*x)</span></span><br><span class="line">    loss = (y_pred-y).<span class="built_in">pow</span>(<span class="number">2</span>).<span class="built_in">sum</span>() <span class="comment">#损失函数</span></span><br><span class="line">    <span class="built_in">print</span>(<span class="string">&quot;Epoch:&#123;&#125;,Loss:&#123;:.4f&#125;&quot;</span>.<span class="built_in">format</span>(epoch,loss))</span><br><span class="line">    </span><br><span class="line">    loss.backward() <span class="comment">#后向传播计算</span></span><br><span class="line">    </span><br><span class="line">    w1.data -= 
learning_rate*w1.grad.data</span><br><span class="line">    w2.data -=learning_rate*w2.grad.data</span><br><span class="line">    </span><br><span class="line">    w1.grad.data.zero_() <span class="comment">#置0</span></span><br><span class="line">    w2.grad.data.zero_()</span><br></pre></td></tr></table></figure><p><strong>使用nn.Module自定义传播函数来搭建神经网络</strong></p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> torch.autograd <span class="keyword">import</span> Variable</span><br><span class="line"></span><br><span class="line">batch_n = <span class="number">100</span></span><br><span 
class="line">hidden_layer = <span class="number">100</span></span><br><span class="line">input_data = <span class="number">1000</span></span><br><span class="line">output_data = <span class="number">10</span></span><br><span class="line"></span><br><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">Model</span>(<span class="params">torch.nn.Module</span>):</span></span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">__init__</span>(<span class="params">self</span>):</span></span><br><span class="line">        <span class="built_in">super</span>(Model,self).__init__()</span><br><span class="line">    </span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">forward</span>(<span class="params">self,input_n,w1,w2</span>):</span></span><br><span class="line">        x = torch.mm(input_n,w1)</span><br><span class="line">        x = torch.clamp(x,<span class="built_in">min</span>=<span class="number">0</span>)</span><br><span class="line">        x = torch.mm(x,w2)</span><br><span class="line">        <span class="keyword">return</span> x</span><br><span class="line">    </span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">backward</span>(<span class="params">self</span>):</span></span><br><span class="line">        <span class="keyword">pass</span></span><br><span class="line">    </span><br><span class="line">model = Model()</span><br><span class="line"></span><br><span class="line">x=Variable(torch.randn(batch_n,input_data),requires_grad = <span class="literal">False</span>) <span class="comment">#requires_grad = False不保留梯度</span></span><br><span class="line">y=Variable(torch.randn(batch_n,output_data),requires_grad = <span class="literal">False</span>)</span><br><span class="line">w1=Variable(torch.randn(input_data,hidden_layer),requires_grad = <span 
class="literal">True</span>) <span class="comment">#requires_grad = True自动保留梯度</span></span><br><span class="line">w2=Variable(torch.randn(hidden_layer,output_data),requires_grad = <span class="literal">True</span>)</span><br><span class="line"></span><br><span class="line">epoch_n = <span class="number">20</span></span><br><span class="line">learning_rate = <span class="number">1e-6</span></span><br><span class="line"></span><br><span class="line"><span class="keyword">for</span> epoch <span class="keyword">in</span> <span class="built_in">range</span>(epoch_n):</span><br><span class="line">    y_pred = model(x,w1,w2)</span><br><span class="line">    loss = (y_pred-y).<span class="built_in">pow</span>(<span class="number">2</span>).<span class="built_in">sum</span>()</span><br><span class="line">    <span class="built_in">print</span>(<span class="string">&quot;Epoch:&#123;&#125;,Loss:&#123;:.4f&#125;&quot;</span>.<span class="built_in">format</span>(epoch,loss))</span><br><span class="line">    loss.backward() <span class="comment">#后向传播计算</span></span><br><span class="line">    </span><br><span class="line">    w1.data -= learning_rate*w1.grad.data</span><br><span class="line">    w2.data -=learning_rate*w2.grad.data</span><br><span class="line">    </span><br><span class="line">    w1.grad.data.zero_() <span class="comment">#置0</span></span><br><span class="line">    w2.grad.data.zero_()</span><br></pre></td></tr></table></figure><h4 id="3-1-3-Dataset-数据集"><a href="#3-1-3-Dataset-数据集" class="headerlink" title="3.1.3 Dataset(数据集)"></a>3.1.3 Dataset(数据集)</h4><p>torch.utils.data.Dataset是代表这一数据的抽象类，可以自己定义数据类继承和重写这个抽象类，只需要定义<code>__len__</code>和<code>__getitem__</code>函数即可</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span 
class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> torch.utils.data <span class="keyword">import</span> Dataset</span><br><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">myDataset</span>(<span class="params">Dataset</span>):</span></span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">__init__</span>(<span class="params">self, csv_file, txt_file, root_dir, other_file</span>):</span></span><br><span class="line">        self.csv_data = pd.read_csv(csv_file)</span><br><span class="line">        <span class="keyword">with</span> <span class="built_in">open</span>(txt_file, <span class="string">&#x27;r&#x27;</span>) <span class="keyword">as</span> f:</span><br><span class="line">            data_list=f.readlines()</span><br><span class="line">        self.txt_data = data_list</span><br><span class="line">        self.root_dir = root_dir</span><br><span class="line">        </span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">__len__</span>(<span class="params">self</span>):</span></span><br><span class="line">        <span class="keyword">return</span> <span class="built_in">len</span>(self.csv_data)</span><br><span class="line">    </span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">__getitem__</span>(<span class="params">self,idx</span>):</span></span><br><span class="line">        data = (self.csv_data[idx],self.txt_data[idx])</span><br><span class="line">        <span class="keyword">return</span> 
data</span><br></pre></td></tr></table></figure><p>通过上面的方式，可以定义需要的数据类，可以通过迭代的方法取得每一个数据，但是这样很难实现取batch，shuffle或者多线程去读取数据，所以Pytorch中提供了torch.utils.data.DataLoader来定义一个新迭代器</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> torch.utils.data <span class="keyword">import</span> DataLoader</span><br><span class="line">dataiter = DataLoader(myDataset,batch_size=<span class="number">32</span>)</span><br></pre></td></tr></table></figure><h4 id="3-1-4-nn-Module-模组"><a href="#3-1-4-nn-Module-模组" class="headerlink" title="3.1.4 nn.Module(模组)"></a>3.1.4 nn.Module(模组)</h4><p>所有的层结构和损失函数来自torch.nn</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> torch <span class="keyword">import</span> nn</span><br><span class="line"></span><br><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">net_name</span>(<span class="params">nn.Module</span>):</span></span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">__init__</span>(<span class="params">self,other_arguments</span>):</span></span><br><span class="line">        <span class="built_in">super</span>(net_name, self).__init__()</span><br><span class="line">        self.conv1 = nn.Conv2d(in_channels,out_channels, kernel_size)</span><br><span class="line">    </span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span 
class="title">forward</span>(<span class="params">self,x</span>):</span></span><br><span class="line">        x = self.conv1(x)</span><br><span class="line">        <span class="keyword">return</span> x</span><br></pre></td></tr></table></figure><p>一个神经网络的典型训练过程如下：</p><ul><li>定义包含一些可学习参数(或者叫权重）的神经网络</li><li>在输入数据集上迭代</li><li>通过网络处理输入</li><li>计算loss(输出和正确答案的距离）</li><li>将梯度反向传播给网络的参数</li><li>更新网络的权重，一般使用一个简单的规则：weight = weight - learning_rate * gradient</li></ul><p><strong>使用torch.nn内的序列容器Sequential</strong></p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br></pre></td><td class="code"><pre><span class="line">batch_n = <span class="number">100</span></span><br><span class="line">hidden_layer = <span class="number">100</span></span><br><span class="line">input_data = <span class="number">1000</span></span><br><span class="line">output_data = <span class="number">10</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># 第一种方式</span></span><br><span class="line">models_1 = torch.nn.Sequential(</span><br><span class="line">    torch.nn.Linear(input_data,hidden_layer),</span><br><span class="line">    torch.nn.ReLU(),</span><br><span class="line">    torch.nn.Linear(hidden_layer,output_data)</span><br><span 
class="line">)</span><br><span class="line"></span><br><span class="line"><span class="comment"># 第二种方式</span></span><br><span class="line"><span class="keyword">from</span> collections <span class="keyword">import</span> OrderedDict</span><br><span class="line">models_2 = torch.nn.Sequential(OrderedDict([</span><br><span class="line">    (<span class="string">&quot;Line1&quot;</span>,torch.nn.Linear(input_data,hidden_layer)),</span><br><span class="line">    (<span class="string">&quot;ReLU1&quot;</span>,torch.nn.ReLU()),</span><br><span class="line">    (<span class="string">&quot;Line2&quot;</span>,torch.nn.Linear(hidden_layer,output_data))])    </span><br><span class="line">)</span><br><span class="line"></span><br><span class="line"><span class="built_in">print</span>(models_1)</span><br><span class="line"><span class="built_in">print</span>(models_2)</span><br></pre></td></tr></table></figure><p><strong>使用nn.Module定义一个神经网络</strong></p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span 
class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">import</span> torch.nn <span class="keyword">as</span> nn</span><br><span class="line"><span class="keyword">import</span> torch.nn.functional <span class="keyword">as</span> F</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">Net</span>(<span class="params">nn.Module</span>):</span></span><br><span class="line"></span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">__init__</span>(<span class="params">self</span>):</span></span><br><span class="line">        <span class="built_in">super</span>(Net, self).__init__()</span><br><span class="line">        <span class="comment"># 输入图像channel：1；输出channel：6；5x5卷积核</span></span><br><span class="line">        self.conv1 = nn.Conv2d(<span class="number">1</span>, <span class="number">6</span>, <span class="number">5</span>)</span><br><span class="line">        self.conv2 = nn.Conv2d(<span class="number">6</span>, <span class="number">16</span>, <span class="number">5</span>)</span><br><span class="line">        <span class="comment"># an affine operation: y = Wx + b</span></span><br><span class="line">        self.fc1 = nn.Linear(<span class="number">16</span> * <span class="number">5</span> * <span class="number">5</span>, <span class="number">120</span>)</span><br><span class="line">        self.fc2 = nn.Linear(<span class="number">120</span>, <span class="number">84</span>)</span><br><span class="line">        self.fc3 = nn.Linear(<span class="number">84</span>, <span 
class="number">10</span>)</span><br><span class="line"></span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">forward</span>(<span class="params">self, x</span>):</span></span><br><span class="line">        <span class="comment"># 2x2 Max pooling</span></span><br><span class="line">        x = F.max_pool2d(F.relu(self.conv1(x)), (<span class="number">2</span>, <span class="number">2</span>))</span><br><span class="line">        <span class="comment"># 如果是方阵,则可以只使用一个数字进行定义</span></span><br><span class="line">        x = F.max_pool2d(F.relu(self.conv2(x)), <span class="number">2</span>)</span><br><span class="line">        x = x.view(-<span class="number">1</span>, self.num_flat_features(x))</span><br><span class="line">        x = F.relu(self.fc1(x))</span><br><span class="line">        x = F.relu(self.fc2(x))</span><br><span class="line">        x = self.fc3(x)</span><br><span class="line">        <span class="keyword">return</span> x</span><br><span class="line"></span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">num_flat_features</span>(<span class="params">self, x</span>):</span></span><br><span class="line">        size = x.size()[<span class="number">1</span>:]  <span class="comment"># 除去批处理维度的其他所有维度</span></span><br><span class="line">        num_features = <span class="number">1</span></span><br><span class="line">        <span class="keyword">for</span> s <span class="keyword">in</span> size:</span><br><span class="line">            num_features *= s</span><br><span class="line">        <span class="keyword">return</span> num_features</span><br><span class="line"></span><br><span class="line"></span><br><span class="line">net = Net()</span><br><span class="line"><span class="built_in">print</span>(net)</span><br></pre></td></tr></table></figure><h4 id="3-1-5-torch-optim-优化"><a href="#3-1-5-torch-optim-优化" class="headerlink" title="3.1.5 
torch.optim(优化)"></a>3.1.5 torch.optim(优化)</h4><p>优化算法分为两大类：</p><p>（1）一阶优化算法<br>使用各个参数的梯度值来更新参数，最常用的是梯度下降。梯度下降的功能是通过寻找最小值，控制方差，更新模型参数，最终使模型收敛，网络的参数更新公式<br>$$\theta = \theta - \eta × \frac{\partial J(\theta)}{\partial \theta}$$<br>其中$\eta$是学习率，$\frac{\partial J(\theta)}{\partial \theta}$是函数的梯度</p><p>（2）二阶优化算法<br>二阶优化算法使用了二阶导数（Hessian方法）来最小化或最大化损失函数，主要是基于牛顿法</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">optimizer=torch.optim.SGD(model.parameters(),lr=<span class="number">0.01</span>,momentum=<span class="number">0.9</span>)</span><br></pre></td></tr></table></figure><h4 id="3-1-6-模型的保存和加载"><a href="#3-1-6-模型的保存和加载" class="headerlink" title="3.1.6 模型的保存和加载"></a>3.1.6 模型的保存和加载</h4><p>1.保存</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">#保存模型</span></span><br><span class="line">torch.save(model,path)</span><br><span class="line"><span class="comment">#保存模型的状态</span></span><br><span class="line">torch.save(model.state_dict(),path)</span><br></pre></td></tr></table></figure><p>2.加载</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">#加载完整的模型</span></span><br><span class="line">load_model = torch.load(path)</span><br><span class="line"><span class="comment">#加载模型参数，需要先导入模型的结构</span></span><br><span class="line">model.load_state_dict(torch.load(path))</span><br></pre></td></tr></table></figure><h3 id="3-2-线性模型"><a href="#3-2-线性模型" class="headerlink" title="3.2 线性模型"></a>3.2 线性模型</h3><h4 id="3-2-1-介绍"><a href="#3-2-1-介绍" class="headerlink"
title="3.2.1 介绍"></a>3.2.1 介绍</h4><p>f(x)=wx+b</p><p>f(x)=w1x1+w2x2+…+wdxd+b</p><p>w和b都是需要学习的参数</p><h4 id="3-2-2-一维线性回归"><a href="#3-2-2-一维线性回归" class="headerlink" title="3.2.2 一维线性回归"></a>3.2.2 一维线性回归</h4><p>给定数据集D={(x1,y1),(x2,y2),…,(xm,ym)}，线性回归希望得到一个f(x)=wx+b能够很好的拟合y</p><p>方法是利用$Loss=\sum_{i=1}^m(f(x_i)-y_i)^2$来衡量误差，即均方误差，那么<br>$$(w^*,b^*)=arg\min_{w,b}\sum_{i=1}^m(f(x_i)-y_i)^2=arg\min_{w,b}\sum_{i=1}^m(y_i-wx_i-b)^2$$</p><p>求解办法：求它的偏导数,并让其为0来估计参数<br>$$\frac{\partial Loss_{(w,b)}}{\partial w} = 2(w\sum_{i=1}^{m}x_i^2-\sum_{i=1}^{m}(y_i-b)x_i)=0$$<br>$$\frac{\partial Loss_{(w,b)}}{\partial b} = 2(mb-\sum_{i=1}^{m}(y_i-wx_i))=0$$<br>得到w和b的最优解<br>$$w=\frac{\sum_{i=1}^{m}y_i(x_i- \bar x)}{\sum_{i=1}^{m}x_i^2-\frac{1}{m}(\sum_{i=1}^{m}x_i)^2}$$<br>$$b=\frac{1}{m}\sum_{i=1}^{m}(y_i-wx_i)$$<br>其中$\bar x$是x的均值<br>$$\bar x = \frac{1}{m}\sum_{i=1}^{m}x_i$$</p><h4 id="3-2-3-多维线性回归"><a href="#3-2-3-多维线性回归" class="headerlink" title="3.2.3 多维线性回归"></a>3.2.3 多维线性回归</h4><p>$$f(x_i)=w^Tx_i+b$$<br>为使得$\sum_{i=1}^{m}(f(x_i)-y_i)^2$最小，这也称为“多元线性回归”，使用最小二乘法对w和b进行估计，假设有d个属性，将w和b写入同一个矩阵，将数据集D表示成一个m×(d+1)的矩阵X，即</p><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br></pre></td><td class="code"><pre><span class="line">X=\begin&#123;bmatrix&#125;</span><br><span class="line">x_&#123;11&#125; &amp; x_&#123;12&#125; &amp; \cdots &amp; x_&#123;1d&#125; &amp; 1 \\</span><br><span class="line">x_&#123;21&#125; &amp; x_&#123;22&#125; &amp; \cdots &amp; x_&#123;2d&#125; &amp; 1 \\</span><br><span class="line">\vdots &amp; \vdots &amp; \ddots &amp; \vdots &amp; \vdots \\</span><br><span
class="line">x_&#123;m1&#125; &amp; x_&#123;m2&#125; &amp; \cdots &amp; x_&#123;md&#125; &amp; 1</span><br><span class="line">\end&#123;bmatrix&#125;=</span><br><span class="line">\begin&#123;bmatrix&#125;</span><br><span class="line">x_1^T &amp; 1\\</span><br><span class="line">x_2^T &amp; 1\\</span><br><span class="line">\vdots &amp; \vdots\\</span><br><span class="line">x_m^T &amp; 1</span><br><span class="line">\end&#123;bmatrix&#125;</span><br></pre></td></tr></table></figure><p><img src="" data-lazy-src="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN2/post/image-20211125184508266.png" alt="image-20211125184508266"></p><p>将目标y也写成乘向量的形式y=(y1,y2,…,ym),那么可得<br>$$w^* = arg \min_w(y-Xw)^T(y-Xw)$$<br>对其求导，令它为0<br>$$\frac{\partial Loss_w}{\partial w}=2X^T(Xw-y)=0$$</p><blockquote><p>上面涉及到矩阵的逆运算，所以需要$X^TX$是一个满秩矩阵或者正定矩阵</p></blockquote><p>可以得到:<br>$$w ^ * =(X^TX)^{-1}X^Ty$$<br>故回归模型可以写成：<br>$$f(x _ i)=x _ i^T(X^TX)^{-1}X^Ty$$</p><h4 id="3-2-4-一维线性回归的代码实现"><a href="#3-2-4-一维线性回归的代码实现" class="headerlink" title="3.2.4 一维线性回归的代码实现"></a>3.2.4 一维线性回归的代码实现</h4><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span 
class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> numpy <span class="keyword">as</span> np</span><br><span class="line"><span class="keyword">import</span> matplotlib.pyplot <span class="keyword">as</span> plt</span><br><span class="line"></span><br><span class="line">x_train = np.array([[<span class="number">3.3</span>],[<span class="number">4.4</span>],[<span class="number">5.5</span>],[<span class="number">6.71</span>],[<span class="number">6.93</span>],[<span class="number">4.168</span>],[<span class="number">9.779</span>],[<span class="number">6.182</span>],[<span class="number">7.59</span>],[<span class="number">2.167</span>],[<span class="number">7.042</span>],[<span class="number">10.791</span>],[<span class="number">5.313</span>],[<span class="number">7.997</span>],[<span class="number">3.1</span>]],dtype=np.float32)</span><br><span class="line">y_train = np.array([[<span class="number">1.7</span>],[<span class="number">2.76</span>],[<span class="number">2.09</span>],[<span class="number">3.19</span>],[<span class="number">1.694</span>],[<span 
class="number">1.573</span>],[<span class="number">3.366</span>],[<span class="number">2.596</span>],[<span class="number">2.53</span>],[<span class="number">1.221</span>],[<span class="number">2.827</span>],[<span class="number">3.465</span>],[<span class="number">1.65</span>],[<span class="number">2.904</span>],[<span class="number">1.3</span>]],dtype=np.float32)</span><br><span class="line"></span><br><span class="line">x_train = torch.from_numpy(x_train)</span><br><span class="line">y_train = torch.from_numpy(y_train)</span><br><span class="line"></span><br><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">LinearRegression</span>(<span class="params">nn.Module</span>):</span></span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">__init__</span>(<span class="params">self</span>):</span></span><br><span class="line">        <span class="built_in">super</span>(LinearRegression,self).__init__() <span class="comment">#继承父类</span></span><br><span class="line">        self.linear = nn.Linear(<span class="number">1</span>,<span class="number">1</span>) <span class="comment"># 1*1</span></span><br><span class="line">    </span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">forward</span>(<span class="params">self,x</span>):</span></span><br><span class="line">        out=self.linear(x)</span><br><span class="line">        <span class="keyword">return</span> out</span><br><span class="line"></span><br><span class="line"><span class="keyword">if</span> torch.cuda.is_available():</span><br><span class="line">    model = LinearRegression().cuda()</span><br><span class="line"><span class="keyword">else</span>:</span><br><span class="line">    model = LinearRegression()</span><br><span class="line"></span><br><span class="line">criterion = torch.nn.MSELoss() <span class="comment"># 均方误差</span></span><br><span 
class="line"><span class="comment">#优化函数，model.parameters()为该实例中可优化的参数，lr为参数优化的选项（学习率等）</span></span><br><span class="line">optimizer = torch.optim.SGD(model.parameters(),lr=<span class="number">1e-3</span>) <span class="comment">#梯度下降</span></span><br><span class="line"></span><br><span class="line">num_epochs = <span class="number">1000</span></span><br><span class="line"></span><br><span class="line"><span class="keyword">for</span> epoch <span class="keyword">in</span> <span class="built_in">range</span>(num_epochs):</span><br><span class="line">    <span class="keyword">if</span> torch.cuda.is_available():</span><br><span class="line">        inputs = Variable(x_train).cuda()</span><br><span class="line">        target = Variable(y_train).cuda()</span><br><span class="line">    <span class="keyword">else</span>:</span><br><span class="line">        inputs = Variable(x_train)</span><br><span class="line">        target = Variable(y_train)</span><br><span class="line">    <span class="comment"># forward</span></span><br><span class="line">    out = model(inputs)</span><br><span class="line">    loss = criterion(out,target) <span class="comment">#均方误差</span></span><br><span class="line">    <span class="comment"># backward</span></span><br><span class="line">    optimizer.zero_grad() <span class="comment">#置0</span></span><br><span class="line">    loss.backward() <span class="comment">#求梯度</span></span><br><span class="line">    optimizer.step() <span class="comment">#更新所有的参数，梯度下降</span></span><br><span class="line">    </span><br><span class="line">    <span class="keyword">if</span>(epoch+<span class="number">1</span>)%<span class="number">50</span>==<span class="number">0</span>:</span><br><span class="line">        <span class="built_in">print</span>(<span class="string">&#x27;Epoch[&#123;&#125;/&#123;&#125;],Loss:&#123;:.6f&#125;&#x27;</span>.<span class="built_in">format</span>(epoch+<span class="number">1</span>,num_epochs,loss))</span><br><span 
class="line"></span><br><span class="line">model.<span class="built_in">eval</span>() <span class="comment">#将模型变成测试模式</span></span><br><span class="line">predict = model(Variable(x_train))</span><br><span class="line">predict = predict.data.numpy()</span><br><span class="line"><span class="comment">#画图</span></span><br><span class="line"><span class="comment">#plt.plot(x_train.numpy(),y_train.numpy(),&#x27;ro&#x27;,label=&#x27;Original data&#x27;)</span></span><br><span class="line"><span class="comment">#plt.plot(x_train.numpy(),predict,label=&quot;Fitting Line&quot;)</span></span><br><span class="line"><span class="comment">#plt.show()</span></span><br></pre></td></tr></table></figure><p><img src="" data-lazy-src="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN2/post/20201022195143.png"></p><h4 id="3-2-5-多项式回归"><a href="#3-2-5-多项式回归" class="headerlink" title="3.2.5 多项式回归"></a>3.2.5 多项式回归</h4><p>对于$y=b+w_1×x+w_2×x^2+w_3×x^3$，预处理数据，变成矩阵形式</p><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line">X=\begin&#123;bmatrix&#125;</span><br><span class="line">x_1 &amp; x_1^2 &amp; x_1^3 \\</span><br><span class="line">x_2 &amp; x_2^2 &amp; x_2^3 \\</span><br><span class="line">\vdots &amp; \ddots &amp; \vdots \\</span><br><span class="line">x_n &amp; x_n^2 &amp; x_n^3</span><br><span class="line">\end&#123;bmatrix&#125;</span><br></pre></td></tr></table></figure><p><img src="" data-lazy-src="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN2/post/image-20211125184611404.png" alt="image-20211125184611404"></p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span 
class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">make_features</span>(<span class="params">x</span>):</span></span><br><span class="line">    x=x.unsqueeze(<span class="number">1</span>)  <span class="comment"># 在第1维（从0开始）增加一维</span></span><br><span class="line">    <span 
class="keyword">return</span> torch.cat([x ** i <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(<span class="number">1</span>,<span class="number">4</span>)],<span class="number">1</span>) <span class="comment">#1代表横着拼接x,x^2,x^3</span></span><br><span class="line"></span><br><span class="line">w_target = torch.FloatTensor([<span class="number">0.5</span>,<span class="number">3</span>,<span class="number">2.4</span>]).unsqueeze(<span class="number">1</span>) <span class="comment"># 在第1维（从0开始）加一层</span></span><br><span class="line">b_target = torch.FloatTensor([<span class="number">0.9</span>])</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">f</span>(<span class="params">x</span>):</span></span><br><span class="line">    <span class="comment">#定义∑wix^i+b</span></span><br><span class="line">    <span class="keyword">return</span> x.mm(w_target) + b_target[<span class="number">0</span>]</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">get_batch</span>(<span class="params">batch_size=<span class="number">32</span></span>):</span></span><br><span class="line">    <span class="comment">#产生数据</span></span><br><span class="line">    random = torch.randn(batch_size)</span><br><span class="line">    x = make_features(random)</span><br><span class="line">    y = f(x)</span><br><span class="line">    <span class="keyword">if</span> torch.cuda.is_available():</span><br><span class="line">        <span class="keyword">return</span> Variable(x).cuda(),Variable(y).cuda()</span><br><span class="line">    <span class="keyword">else</span>:</span><br><span class="line">        <span class="keyword">return</span> Variable(x),Variable(y)</span><br><span class="line"></span><br><span class="line"><span class="class"><span class="keyword">class</span> <span 
class="title">poly_model</span>(<span class="params">nn.Module</span>):</span></span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">__init__</span>(<span class="params">self</span>):</span></span><br><span class="line">        <span class="built_in">super</span>(poly_model,self).__init__()</span><br><span class="line">        self.poly = nn.Linear(<span class="number">3</span>,<span class="number">1</span>)</span><br><span class="line"></span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">forward</span>(<span class="params">self,x</span>):</span></span><br><span class="line">        out = self.poly(x)</span><br><span class="line">        <span class="keyword">return</span> out</span><br><span class="line">    </span><br><span class="line"><span class="keyword">if</span> torch.cuda.is_available():</span><br><span class="line">    model = poly_model().cuda()</span><br><span class="line"><span class="keyword">else</span>:</span><br><span class="line">    model = poly_model()</span><br><span class="line">    </span><br><span class="line">criterion = nn.MSELoss() <span class="comment"># 均方误差</span></span><br><span class="line">optimizer = torch.optim.SGD(model.parameters(),lr=<span class="number">1e-3</span>)<span class="comment">#梯度下降</span></span><br><span class="line"></span><br><span class="line">epoch = <span class="number">0</span></span><br><span class="line"></span><br><span class="line"><span class="keyword">while</span> <span class="literal">True</span>:</span><br><span class="line">    batch_x,batch_y = get_batch()</span><br><span class="line">    <span class="comment">#前向传播</span></span><br><span class="line">    output = model(batch_x)</span><br><span class="line">    loss = criterion(output,batch_y)</span><br><span class="line">    epoch+=<span class="number">1</span></span><br><span class="line">    <span class="keyword">if</span> epoch%<span 
class="number">50</span> ==<span class="number">0</span>:</span><br><span class="line">        <span class="built_in">print</span>(<span class="string">&quot;Epoch:&#123;&#125;,Loss:&#123;:.6f&#125;&quot;</span>.<span class="built_in">format</span>(epoch,loss.data.item()))</span><br><span class="line">    optimizer.zero_grad() <span class="comment">#置0</span></span><br><span class="line">    loss.backward() <span class="comment">#后向传播</span></span><br><span class="line">    optimizer.step() <span class="comment">#优化参数</span></span><br><span class="line">    </span><br><span class="line">    <span class="keyword">if</span> loss &lt;<span class="number">1e-2</span>:</span><br><span class="line">        <span class="keyword">break</span></span><br><span class="line">    </span><br></pre></td></tr></table></figure><blockquote><p>注意：<br><code>torch.nn</code>只支持小批量处理<code>（mini-batches）</code>。整个<code>torch.nn</code>包只支持小批量样本的输入，不支持单个样本的输入。<br>比如，<code>nn.Conv2d</code> 接受一个4维的张量，即<code>nSamples x nChannels x Height x Width</code>.<br>如果是一个单独的样本，只需要使用<code>input.unsqueeze(0)</code>来添加一个“假的”批大小维度。</p></blockquote><h3 id="3-3-分类问题"><a href="#3-3-分类问题" class="headerlink" title="3.3 分类问题"></a>3.3 分类问题</h3><h4 id="3-3-1-问题介绍"><a href="#3-3-1-问题介绍" class="headerlink" title="3.3.1 问题介绍"></a>3.3.1 问题介绍</h4><p>机器学习中的监督学习主要分为回归问题和分类问题，对于回归问题，希望预测的结果是连续的，对于分类问题所预测的结果是离散的。</p><p>监督学习从数据中学习一个分类模型或者分类决策函数，被称为分类器</p><h4 id="3-3-2-Logistic起源"><a href="#3-3-2-Logistic起源" class="headerlink" title="3.3.2 Logistic起源"></a>3.3.2 Logistic起源</h4><p>著名的二分类算法，Logistic回归。起源于对人口数量增长情况的研究</p><h4 id="3-3-3-Logistic分布"><a href="#3-3-3-Logistic分布" class="headerlink" title="3.3.3 Logistic分布"></a>3.3.3 Logistic分布</h4><p>设x是连续的随机变量，服从Logistic分布是指X的分布函数和密度函数如下<br>$$F(x)=P(X≤x)=\frac{1}{1+e^{-(x-\mu)/\gamma}}$$<br>$$f(x)=\frac{e^{-(x-\mu)/\gamma}}{\gamma (1+e^{-(x-\mu)/\gamma})^2}$$<br>其中μ影响中心对称点的位置，γ越小中心点附近的增长速度越快<br>Sigmoid函数是Logistic分布函数中γ=1，μ=0的特殊形式，表达式如下：$$p(x)=\frac{1}{1+e^{-x}}$$</p><h4
id="3-3-4-二分类的Logistic回归"><a href="#3-3-4-二分类的Logistic回归" class="headerlink" title="3.3.4 二分类的Logistic回归"></a>3.3.4 二分类的Logistic回归</h4><p>假设输入的数据的特征向量$x∈R^n$，那么决策边界可以表示为$\sum_{i=1}^{n}w_ix_i+b=0$，假设存在一个样本点使得$h_w(x)=\sum_{i=1}^{n}w_ix_i+b&gt;0$，那么判定它的类别是1，如果&lt;0，判定其类别是0.<br>Logistic回归通过找到分类概率P(Y=1)与输入变量x的直接关系，然后通过比较概率值来判断类别，简单来说就是通过计算下面两个概率分布<br>$$P(Y=0|x)=\frac{1}{1+e^{wx+b}}$$<br>$$P(Y=1|x)=\frac{e^{wx+b}}{1+e^{wx+b}}$$<br>其中w是权重，b是偏置</p><blockquote><p>一个事件发生的几率（odds）是指该事件发生的概率p与不发生的概率1-p的比值，该事件的对数几率或logit函数是：$logit(p)=log\frac{p}{1-p}$</p></blockquote><p>对于Logistic回归而言，可以得到：<br>$$log \frac{P(Y=1|x)}{1-P(Y=1|x)}=wx+b$$</p><h4 id="3-3-5-模型的参数估计"><a href="#3-3-5-模型的参数估计" class="headerlink" title="3.3.5 模型的参数估计"></a>3.3.5 模型的参数估计</h4><p>对于给定的训练集数据T={(x1,y1),(x2,y2),…,(xn,yn)}，其中$x_i \in R^n, y_i \in \{0,1\}$，假设P(Y=1|x)=Π(x)，那么P(Y=0|x)=1-Π(x)，所以似然函数为：<br>$$\prod_{i=1}^{n}[\pi (x_i)]^{y_i}[1-\pi (x_i)]^{1-y_i}$$<br>取对数后的对数似然函数：<br>$$L(w)=\sum_{i=1}^{n}[y_i(wx_i+b)-log(1+e^{wx_i+b})]$$<br>用L(w)对w求导：<br>$$\frac{\partial L(w)}{\partial w}=\sum_{i=1}^{n}y_ix_i-\sum_{i=1}^{n}\frac{e^{wx_i+b}}{1+e^{wx_i+b}}x_i=\sum_{i=1}^{n}(y_i-\pi (x_i))x_i$$<br>$$\frac{\partial L(w)}{\partial b}=\sum_{i=1}^{n}y_i-\sum_{i=1}^{n}\frac{e^{wx_i+b}}{1+e^{wx_i+b}}=\sum_{i=1}^{n}(y_i-\pi (x_i))$$</p><h4 id="3-3-6-Logistic回归的代码实现"><a href="#3-3-6-Logistic回归的代码实现" class="headerlink" title="3.3.6 Logistic回归的代码实现"></a>3.3.6 Logistic回归的代码实现</h4><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span
class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br></pre></td><td class="code"><pre><span 
class="line"><span class="keyword">import</span> requests</span><br><span class="line"></span><br><span class="line"><span class="comment">#获取数据</span></span><br><span class="line">url=<span class="string">&quot;https://cdn.jsdelivr.net/gh/Justlovesmile/code-of-learn-deep-learning-with-pytorch/chapter3_NN/logistic-regression/data.txt&quot;</span></span><br><span class="line">data = requests.get(url)</span><br><span class="line">data_list=data.text.split(<span class="string">&#x27;\n&#x27;</span>)[:-<span class="number">1</span>]</span><br><span class="line">data_list=[i.split(<span class="string">&#x27;,&#x27;</span>) <span class="keyword">for</span> i <span class="keyword">in</span> data_list]</span><br><span class="line">data = [(<span class="built_in">float</span>(i[<span class="number">0</span>]),<span class="built_in">float</span>(i[<span class="number">1</span>]),<span class="built_in">float</span>(i[<span class="number">2</span>])) <span class="keyword">for</span> i <span class="keyword">in</span> data_list]</span><br><span class="line"></span><br><span class="line">np_data = np.array(data, dtype=<span class="string">&#x27;float32&#x27;</span>) <span class="comment"># 转换成 numpy array</span></span><br><span class="line">x_data = torch.from_numpy(np_data[:, <span class="number">0</span>:<span class="number">2</span>]) <span class="comment"># 转换成 Tensor, 大小是 [100, 2]</span></span><br><span class="line">y_data = torch.from_numpy(np_data[:, -<span class="number">1</span>]).unsqueeze(<span class="number">1</span>) <span class="comment"># 转换成 Tensor，大小是 [100, 1]</span></span><br><span class="line"></span><br><span class="line"><span class="comment">#print(x_data,y_data)</span></span><br><span class="line"></span><br><span class="line"><span class="comment">#画数据的散点图</span></span><br><span class="line">x0=<span class="built_in">list</span>(<span class="built_in">filter</span>(<span class="keyword">lambda</span> x:x[-<span class="number">1</span>]==<span 
class="number">0.0</span>,data))</span><br><span class="line">x1=<span class="built_in">list</span>(<span class="built_in">filter</span>(<span class="keyword">lambda</span> x:x[-<span class="number">1</span>]==<span class="number">1.0</span>,data))</span><br><span class="line">plot_x0_0 = [i[<span class="number">0</span>] <span class="keyword">for</span> i <span class="keyword">in</span> x0]</span><br><span class="line">plot_x0_1 = [i[<span class="number">1</span>] <span class="keyword">for</span> i <span class="keyword">in</span> x0]</span><br><span class="line">plot_x1_0 = [i[<span class="number">0</span>] <span class="keyword">for</span> i <span class="keyword">in</span> x1]</span><br><span class="line">plot_x1_1 = [i[<span class="number">1</span>] <span class="keyword">for</span> i <span class="keyword">in</span> x1]</span><br><span class="line"></span><br><span class="line">plt.plot(plot_x0_0,plot_x0_1,<span class="string">&#x27;ro&#x27;</span>,label=<span class="string">&quot;x_0&quot;</span>) <span class="comment">#0类用红色</span></span><br><span class="line">plt.plot(plot_x1_0,plot_x1_1,<span class="string">&#x27;bo&#x27;</span>,label=<span class="string">&quot;x_1&quot;</span>) <span class="comment">#1类用蓝色</span></span><br><span class="line">plt.legend(loc=<span class="string">&#x27;best&#x27;</span>) <span class="comment">#图例的位置</span></span><br><span class="line"></span><br><span class="line"><span class="comment">#分类</span></span><br><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">LogisticRegression</span>(<span class="params">nn.Module</span>):</span></span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">__init__</span>(<span class="params">self</span>):</span></span><br><span class="line">        <span class="built_in">super</span>(LogisticRegression,self).__init__() <span class="comment">#继承</span></span><br><span class="line">        self.lr = 
nn.Linear(<span class="number">2</span>,<span class="number">1</span>) <span class="comment">#2*1</span></span><br><span class="line">        self.sm = nn.Sigmoid() <span class="comment">#sigmoid函数</span></span><br><span class="line">        </span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">forward</span>(<span class="params">self,x</span>):</span></span><br><span class="line">        x=self.lr(x)</span><br><span class="line">        x=self.sm(x)</span><br><span class="line">        <span class="keyword">return</span> x</span><br><span class="line">    </span><br><span class="line">logistic_model = LogisticRegression()</span><br><span class="line"><span class="keyword">if</span> torch.cuda.is_available():</span><br><span class="line">    logistic_model.cuda()</span><br><span class="line"></span><br><span class="line">criterion = nn.BCELoss() <span class="comment">#二分类的损失函数</span></span><br><span class="line"><span class="comment">#随机梯度下降优化，parameters是可优化参数，lr是学习率，momentum是动量因子</span></span><br><span class="line">optimizer = torch.optim.SGD(logistic_model.parameters(),lr=<span class="number">1e-3</span>,momentum=<span class="number">0.9</span>)</span><br><span class="line"></span><br><span class="line"><span class="keyword">for</span> epoch <span class="keyword">in</span> <span class="built_in">range</span>(<span class="number">20000</span>):</span><br><span class="line">    <span class="keyword">if</span> torch.cuda.is_available():</span><br><span class="line">        x=Variable(x_data).cuda()</span><br><span class="line">        y=Variable(y_data).cuda()</span><br><span class="line">    <span class="keyword">else</span>:</span><br><span class="line">        x=Variable(x_data)</span><br><span class="line">        y=Variable(y_data)</span><br><span class="line">    <span class="comment">#forward</span></span><br><span class="line">    out = logistic_model(x)</span><br><span class="line">    loss = 
criterion(out,y)</span><br><span class="line">    mask = out.ge(<span class="number">0.5</span>).<span class="built_in">float</span>() <span class="comment">#if out&gt;0.5,out=1,else out=0</span></span><br><span class="line">    acc = <span class="built_in">float</span>((mask == y_data).<span class="built_in">sum</span>().item()) / y_data.shape[<span class="number">0</span>]</span><br><span class="line">    <span class="comment">#backward</span></span><br><span class="line">    optimizer.zero_grad()</span><br><span class="line">    loss.backward()</span><br><span class="line">    optimizer.step()</span><br><span class="line">    <span class="keyword">if</span>(epoch+<span class="number">1</span>)%<span class="number">2000</span> ==<span class="number">0</span>:</span><br><span class="line">        <span class="built_in">print</span>(<span class="string">&#x27;*&#x27;</span>*<span class="number">10</span>)</span><br><span class="line">        <span class="built_in">print</span>(<span class="string">&#x27;Epoch: &#123;&#125;,Loss: &#123;:.4f&#125;,Acc: &#123;:.4f&#125;&#x27;</span>.<span class="built_in">format</span>(epoch+<span class="number">1</span>,loss,acc))</span><br><span class="line"></span><br><span class="line"><span class="comment"># 画线w1x+w2y+b=0</span></span><br><span class="line">w0,w1 = logistic_model.lr.weight[<span class="number">0</span>]</span><br><span class="line">b = logistic_model.lr.bias.data[<span class="number">0</span>]</span><br><span class="line">plot_x = np.arange(<span class="number">30</span>,<span class="number">100</span>,<span class="number">0.1</span>)</span><br><span class="line">w0=w0.data</span><br><span class="line">w1=w1.data</span><br><span class="line">plot_y = (-w0*plot_x-b) /w1</span><br><span class="line">plt.plot(plot_x,plot_y)</span><br><span class="line">plt.show()</span><br></pre></td></tr></table></figure><p><img src="" data-lazy-src="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN2/post/20201023124844.png"></p><h3 
id="3-4-简单多层全连接前向网络"><a href="#3-4-简单多层全连接前向网络" class="headerlink" title="3.4 简单多层全连接前向网络"></a>3.4 简单多层全连接前向网络</h3><h4 id="3-4-1-模拟神经元"><a href="#3-4-1-模拟神经元" class="headerlink" title="3.4.1 模拟神经元"></a>3.4.1 模拟神经元</h4><p>神经网络就是受到了模拟脑神经元的启发</p><h4 id="3-4-2-单层神经网络的分类器"><a href="#3-4-2-单层神经网络的分类器" class="headerlink" title="3.4.2 单层神经网络的分类器"></a>3.4.2 单层神经网络的分类器</h4><p>例如之前的Logistic回归，是使用了sigmoid函数作为激活函数的一层神经网络</p><h4 id="3-4-3-激活函数"><a href="#3-4-3-激活函数" class="headerlink" title="3.4.3 激活函数"></a>3.4.3 激活函数</h4><p>1.Sigmoid函数</p><p>$$\sigma (x)=\frac{1}{1+e^{-x}}$$</p><p>缺点：<br>（1）造成梯度消失。在靠近0，1两端，梯度几乎为0，导致没有信息来更新参数<br>（2）输出不是以0为均值。</p><p>2.Tanh</p><p>$$tanh(x)=2\sigma(2x)-1$$</p><p>Tanh激活函数是sigmoid函数的变形，将输入的数据转化到-1到1之间，解决了sigmoid函数第二个问题，但仍存在梯度消失的问题</p><p>3.ReLU</p><p>ReLU的数学表达式为$f(x)=max(0,x)$</p><p>优点：<br>（1）相比较sigmoid和tanh，ReLU可以极大地加速随机梯度下降法的收敛速度，因为是线性的，不存在梯度消失<br>（2）计算方法更简单</p><p>缺点：<br>训练的时候很脆弱，一个很大的梯度经过ReLU激活函数，更新参数之后，会使得这个神经元不会对任何数据有激活现象，之后再经过ReLU的梯度都是0，参数无法更新。可以通过设置较小的学习率来避免这个问题</p><p>4.Leaky ReLU</p><p>ReLU的变式，为了修复ReLU脆弱的缺点，将x&lt;0的部分变成一个很小的负的斜率，但是效果时好时不好</p><p>5.Maxout</p><p>$$f(x)=max(w_1x+b_1,w_2x+b_2)$$<br>ReLU只是Maxout中w1=0，b1=0的特殊形式</p><p>优点：包含ReLU的优点，避免了ReLU的脆弱性<br>缺点：参数存储变大</p><h4 id="3-4-4-神经网络的结构"><a href="#3-4-4-神经网络的结构" class="headerlink" title="3.4.4 神经网络的结构"></a>3.4.4 神经网络的结构</h4><p>神经网络是一个由神经元组成的无环图</p><p>nn.Linear(in,out，bias=False)是全连接神经网络层的函数</p><h4 id="3-4-5-模型的表示能力与容量"><a href="#3-4-5-模型的表示能力与容量" class="headerlink" title="3.4.5 模型的表示能力与容量"></a>3.4.5 模型的表示能力与容量</h4><p>在实际中，我们可能发现一个三层的全连接神经网络比一个两层的全连接神经网络表现更好，但是更深的网络结构对全连接神经网络效果提升表现不大。<br>我们需要注意的是，增大网络的层数和每层的节点数，相当于在增大网络的容量，容量的增大意味着网络有着更大的潜在表现能力。</p><p>但是当我们在做一个二分类问题时，更复杂的模型或许有着更复杂的形状，能将测试用例完美的分类，但是却忽略了潜在的数学关系，将噪声的干扰放大，这种效果被称为过拟合</p><h3 id="3-5-深度学习的基石：反向传播算法"><a href="#3-5-深度学习的基石：反向传播算法" class="headerlink" title="3.5 深度学习的基石：反向传播算法"></a>3.5 深度学习的基石：反向传播算法</h3><h4 id="3-5-1-链式法则"><a href="#3-5-1-链式法则" class="headerlink" title="3.5.1 链式法则"></a>3.5.1 链式法则</h4><p>求导的链式法则（高数知识）</p><h4 
id="3-5-2-反向传播算法"><a href="#3-5-2-反向传播算法" class="headerlink" title="3.5.2 反向传播算法"></a>3.5.2 反向传播算法</h4><p>是链式求导法则的应用</p><p>局部求导，不断迭代传播</p><h3 id="3-6-各种优化算法的变式"><a href="#3-6-各种优化算法的变式" class="headerlink" title="3.6 各种优化算法的变式"></a>3.6 各种优化算法的变式</h3><h4 id="3-6-1-梯度下降法"><a href="#3-6-1-梯度下降法" class="headerlink" title="3.6.1 梯度下降法"></a>3.6.1 梯度下降法</h4><p>梯度下降的更新公式<br>$$x^i=x^{i-1}-\eta \nabla L(x^{i-1})$$</p><h4 id="3-6-2-梯度下降法的变式"><a href="#3-6-2-梯度下降法的变式" class="headerlink" title="3.6.2 梯度下降法的变式"></a>3.6.2 梯度下降法的变式</h4><p>1.SGD<br>随机梯度下降法，每次使用一批（batch）数据进行梯度的计算，而不是全部数据的梯度</p><p>2.Momentum<br>在随机梯度下降的同时，增加动量（momentum），帮助跳出一些鞍点或局部极小值点</p><p>3.Adagrad<br>自适应学习率的方法，公式是<br>$$w^{t+1}←w^{t}-\frac{\eta}{\sqrt{\sum_{i=0}^{t}(g^i)^2}+\varepsilon }$$</p><p>学习率在不断变小，但是在某些情况下会导致学习过早停止</p><p>4.RMSprop<br>一种非常有效的自适应学习率的改进方法，公式是<br>$$cache^t=\alpha * cache^{t-1}+(1-\alpha)(g^t)^2$$<br>$$w^{t+1}←w^{t}-\frac{\eta}{\sqrt{cache^t+\varepsilon}}g^t$$<br>其中α是衰减率，能有效避免Adagrad学习率一直递减太多的问题，能够更快地收敛</p><p>5.Adam<br>一种综合型学习方法，可以看成RMSprop加上momentum的学习方法</p><h3 id="3-7-处理数据和训练模型的技巧"><a href="#3-7-处理数据和训练模型的技巧" class="headerlink" title="3.7 处理数据和训练模型的技巧"></a>3.7 处理数据和训练模型的技巧</h3><h4 id="3-7-1-数据预处理"><a href="#3-7-1-数据预处理" class="headerlink" title="3.7.1 数据预处理"></a>3.7.1 数据预处理</h4><p>1.中心化<br>变成0均值</p><p>2.标准化<br>使得每个特征维度的最大值和最小值按比例缩放到-1到1之间</p><p>3.PCA（主成分分析）<br>将数据去相关性，将其投影到一个特征空间，取一些较大的，主要的特征向量来降低数据的维度</p><p>4.白噪声<br>将数据投影到一个特征空间，然后每个维度除以特征值来标准化这些数据</p><h4 id="3-7-2-权重初始化"><a href="#3-7-2-权重初始化" class="headerlink" title="3.7.2 权重初始化"></a>3.7.2 权重初始化</h4><p>1.全0初始化<br>不应该采用这种策略</p><p>2.随机初始化<br>包括了高斯随机化，均匀随机化</p><p>3.稀疏初始化</p><p>4.初始化偏置</p><p>5.批标准化</p><h4 id="3-7-3-防止过拟合"><a href="#3-7-3-防止过拟合" class="headerlink" title="3.7.3 防止过拟合"></a>3.7.3 防止过拟合</h4><p>1.正则化<br>2.Dropout</p><h3 id="3-8-多层全连接神经网络实现MNIST手写数字分类"><a href="#3-8-多层全连接神经网络实现MNIST手写数字分类" class="headerlink" title="3.8 多层全连接神经网络实现MNIST手写数字分类"></a>3.8 多层全连接神经网络实现MNIST手写数字分类</h3><figure class="highlight python"><table><tr><td 
class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span 
class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br><span class="line">87</span><br><span class="line">88</span><br><span class="line">89</span><br><span class="line">90</span><br><span class="line">91</span><br><span class="line">92</span><br><span class="line">93</span><br><span class="line">94</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">from</span> torch <span class="keyword">import</span> nn,optim</span><br><span class="line"><span class="keyword">from</span> torch.autograd <span class="keyword">import</span> Variable</span><br><span class="line"><span class="keyword">from</span> torch.utils.data <span class="keyword">import</span> DataLoader</span><br><span class="line"><span class="keyword">from</span> torchvision <span class="keyword">import</span> datasets,transforms</span><br><span class="line"></span><br><span class="line"><span class="comment">#带有批标准化和激活函数的三层全连接神经网络</span></span><br><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">Batch_Net</span>(<span class="params">nn.Module</span>):</span></span><br><span class="line">    <span class="function"><span 
class="keyword">def</span> <span class="title">__init__</span>(<span class="params">self,in_dim,n_hidden_1,n_hidden_2,out_dim</span>):</span></span><br><span class="line">        <span class="built_in">super</span>(Batch_Net,self).__init__()</span><br><span class="line">        self.layer1 = nn.Sequential(nn.Linear(in_dim,n_hidden_1),nn.BatchNorm1d(n_hidden_1),nn.ReLU(<span class="literal">True</span>))</span><br><span class="line">        self.layer2 = nn.Sequential(nn.Linear(n_hidden_1,n_hidden_2),nn.BatchNorm1d(n_hidden_2),nn.ReLU(<span class="literal">True</span>))</span><br><span class="line">        self.layer3 = nn.Sequential(nn.Linear(n_hidden_2,out_dim))</span><br><span class="line">    </span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">forward</span>(<span class="params">self,x</span>):</span></span><br><span class="line">        x=self.layer1(x)</span><br><span class="line">        x=self.layer2(x)</span><br><span class="line">        x=self.layer3(x)</span><br><span class="line">        <span class="keyword">return</span> x</span><br><span class="line"></span><br><span class="line">batch_size = <span class="number">64</span></span><br><span class="line">learning_rate = <span class="number">1e-2</span></span><br><span class="line">num_epoch = <span class="number">20</span></span><br><span class="line"></span><br><span class="line"><span class="comment">#transforms.ToTensor()将图片转换成PyTorch中从处理的对象，并自动将图片标准化了，即范围0到1</span></span><br><span class="line"><span class="comment">#transforms.Normalize(均值，方差)，处理：减均值，除以方差</span></span><br><span class="line"><span class="comment">#图片为灰度图，只有一个通道，如果是三通道则为transforms.Normalize([a,b,c],[d,e,f])</span></span><br><span class="line">data_tf = transforms.Compose(</span><br><span class="line">    [transforms.ToTensor(),transforms.Normalize([<span class="number">0.5</span>],[<span class="number">0.5</span>])]</span><br><span class="line">)</span><br><span 
class="line"></span><br><span class="line"><span class="comment"># 获取数据集</span></span><br><span class="line">train_dataset = datasets.MNIST(root=<span class="string">&quot;./data&quot;</span>,train=<span class="literal">True</span>,transform=data_tf,download=<span class="literal">True</span>)</span><br><span class="line">test_dataset = datasets.MNIST(root=<span class="string">&quot;./data&quot;</span>,train=<span class="literal">False</span>,transform=data_tf)</span><br><span class="line"><span class="comment"># 数据迭代器，传入数据集和batch_size，通过shuffle=True来表示是否将数据打乱</span></span><br><span class="line">train_loader = DataLoader(train_dataset,batch_size=batch_size,shuffle=<span class="literal">True</span>)</span><br><span class="line">test_loader = DataLoader(test_dataset,batch_size=batch_size,shuffle=<span class="literal">False</span>)</span><br><span class="line"></span><br><span class="line">model = Batch_Net(<span class="number">28</span>*<span class="number">28</span>,<span class="number">300</span>,<span class="number">100</span>,<span class="number">10</span>)</span><br><span class="line"><span class="keyword">if</span> torch.cuda.is_available():</span><br><span class="line">    model = model.cuda()</span><br><span class="line"></span><br><span class="line">criterion = nn.CrossEntropyLoss() <span class="comment">#交叉熵</span></span><br><span class="line"><span class="comment"># 优化</span></span><br><span class="line">optimizer = optim.SGD(model.parameters(),lr=learning_rate)</span><br><span class="line"></span><br><span class="line"><span class="comment">#训练</span></span><br><span class="line"><span class="keyword">for</span> epoch <span class="keyword">in</span> <span class="built_in">range</span>(num_epoch):</span><br><span class="line">    eval_loss = <span class="number">0</span></span><br><span class="line">    eval_acc = <span class="number">0</span></span><br><span class="line">    <span class="keyword">for</span> data <span class="keyword">in</span> 
train_loader:</span><br><span class="line">        img,label=data</span><br><span class="line">        img = img.view(img.size(<span class="number">0</span>),-<span class="number">1</span>)</span><br><span class="line">        <span class="keyword">if</span> torch.cuda.is_available():</span><br><span class="line">            img = Variable(img).cuda()</span><br><span class="line">            label = Variable(label).cuda()</span><br><span class="line">        <span class="keyword">else</span>:</span><br><span class="line">            img = Variable(img)</span><br><span class="line">            label = Variable(label)</span><br><span class="line">        out=model(img)</span><br><span class="line">        loss = criterion(out,label)</span><br><span class="line">        <span class="comment"># backward</span></span><br><span class="line">        optimizer.zero_grad() <span class="comment">#置0</span></span><br><span class="line">        loss.backward() <span class="comment">#求梯度</span></span><br><span class="line">        optimizer.step() <span class="comment">#更新所有的参数，梯度下降</span></span><br><span class="line">        <span class="comment">#acc</span></span><br><span class="line">        eval_loss +=loss*label.size(<span class="number">0</span>)</span><br><span class="line">        _,pred = torch.<span class="built_in">max</span>(out,<span class="number">1</span>)</span><br><span class="line">        num_correct = (pred == label).<span class="built_in">sum</span>()</span><br><span class="line">        eval_acc +=num_correct</span><br><span class="line">        <span class="built_in">print</span>(<span class="string">&#x27;Epoch:&#123;&#125;,Loss: &#123;:.6f&#125;,Acc:&#123;:.6f&#125;&#x27;</span>.<span class="built_in">format</span>(epoch,eval_loss/(<span class="built_in">len</span>(train_dataset)),<span class="built_in">float</span>(eval_acc)/(<span class="built_in">len</span>(train_dataset))))</span><br><span class="line"></span><br><span class="line">        
</span><br><span class="line"><span class="comment">#测试</span></span><br><span class="line">model.<span class="built_in">eval</span>()</span><br><span class="line">eval_loss = <span class="number">0</span></span><br><span class="line">eval_acc = <span class="number">0</span></span><br><span class="line"><span class="keyword">for</span> data <span class="keyword">in</span> test_loader:</span><br><span class="line">    img,label=data</span><br><span class="line">    img = img.view(img.size(<span class="number">0</span>),-<span class="number">1</span>)</span><br><span class="line">    <span class="keyword">if</span> torch.cuda.is_available():</span><br><span class="line">        img = Variable(img).cuda()</span><br><span class="line">        label = Variable(label).cuda()</span><br><span class="line">    <span class="keyword">else</span>:</span><br><span class="line">        img = Variable(img)</span><br><span class="line">        label = Variable(label)</span><br><span class="line">    out=model(img)</span><br><span class="line">    loss = criterion(out,label)</span><br><span class="line">    eval_loss +=loss.data*label.size(<span class="number">0</span>)</span><br><span class="line">    _,pred = torch.<span class="built_in">max</span>(out,<span class="number">1</span>)</span><br><span class="line">    num_correct = (pred == label).<span class="built_in">sum</span>()</span><br><span class="line">    eval_acc +=num_correct.data</span><br><span class="line"></span><br><span class="line"><span class="built_in">print</span>(<span class="string">&#x27;Test Loss: &#123;:.6f&#125;,Acc:&#123;:.6f&#125;&#x27;</span>.<span class="built_in">format</span>(eval_loss/(<span class="built_in">len</span>(test_dataset)),<span class="built_in">float</span>(eval_acc)/(<span class="built_in">len</span>(test_dataset))))</span><br></pre></td></tr></table></figure><h2 id="第四章-卷积神经网络"><a href="#第四章-卷积神经网络" class="headerlink" title="第四章 卷积神经网络"></a>第四章 卷积神经网络</h2><p>1998年由Yann 
Lecun提出，2012年Alex Krizhevsky凭借它赢得了ImageNet挑战赛</p><h3 id="4-1-主要任务及起源"><a href="#4-1-主要任务及起源" class="headerlink" title="4.1 主要任务及起源"></a>4.1 主要任务及起源</h3><p>对于计算机视觉，主要用于提取图像中的特征</p><h3 id="4-2-卷积神经网络的原理和结构"><a href="#4-2-卷积神经网络的原理和结构" class="headerlink" title="4.2 卷积神经网络的原理和结构"></a>4.2 卷积神经网络的原理和结构</h3><p>一，卷积神经网络的三种思想</p><p>1.局部性</p><p>对于图片而言，需要检测图片中的特征来决定图片的类别，通常情况下这些特征都不是由整张图片决定的，而是由一些局部的区域决定的</p><p>2.相同性</p><p>对不同图片，如果具有同样的特征，这些特征会出现在不同位置，但特征检测所作的操作几乎一样</p><p>3.不变性</p><p>对于一张大图片，如果进行下采样，那么图片的性质基本保持不变</p><p>二，卷积神经网络的层结构</p><p>对于全连接神经网络，其由一系列隐藏层构成，每个隐藏层由若干个神经元构成，其中每个神经元都和前一层的所有神经元相关联，但是每一层中的神经元是相互独立的。全连接神经网络在处理图片时，比如在MNIST数据集上，图片大小是28×28，那么每层的单个神经元的权重数目就是28×28=784，但这只是一张小图片，且只有一个通道，如果是大图片，那么就会导致参数增长特别快，所以全连接神经网络在处理图像并不是好的选择</p><p>而卷积神经网络是一个3D容量的神经元，每个神经元由三个维度排列：宽度，高度和深度。如果输入的图片是32×32×3，那么这张图片的宽度就是32，高度也是32，深度是3</p><p>卷积神经网络的主要层结构有三个：卷积层，池化层，全连接层，通过堆叠这些层结构形成了一个完整的卷积神经网络结构，其中一些层包含参数（如：卷积层，全连接层），一些层不包含参数（如：激活层，池化层）。</p><h4 id="4-2-1-卷积层"><a href="#4-2-1-卷积层" class="headerlink" title="4.2.1 卷积层"></a>4.2.1 
卷积层</h4><p>卷积层是卷积神经网络的核心</p><p>1.概述</p><p>卷积神经网络的参数，是由一些可学习的滤波器集合构成，每个滤波器在空间上（宽度和高度）都比较小，但深度和输入数据的深度保持一致。在前向传播时，让每个滤波器都在输入数据的宽度和高度上滑动（卷积），然后计算整个滤波器和输入数据任意一处的内积。<br>当滤波器沿着输入数据的宽度和高度滑动时，会生成一个二维的激活图。每个卷积层上，会有一整个集合的滤波器，这样会形成多个二维的不同的激活图，将这些激活图在深度方向堆叠起来形成卷积层的输出</p><p>2.局部连接</p><p>与神经元连接的空间大小叫做神经元的感受野，其大小是一个人为设置的超参数，其实是滤波器的宽和高</p><p>3.空间排列</p><p>卷积层的输出深度是一个超参数，与使用的滤波器数量一致，并且在滑动滤波器的时候必须指定步长</p><p>4.边界填充</p><p>可以将输入数据用0在边界进行填充，用来控制输出数据在空间上的尺寸，输出的尺寸可以用一个公式来计算，$\frac{W-F+2P}{S}+1$，其中W是输入的数据大小，F表示卷积层中神经元的感受野尺寸，S表示步长，P表示边界填充0的数量</p><p>5.步长的限制</p><p>步长的选择是有所限制的。当输入尺寸W是10时，如果不使用0填充，即P=0，滤波器尺寸F=3，这样步长S=2就行不通，因为(10-3+0)/2+1=4.5，不是一个整数，说明神经元不能整齐对称地滑过输入数据体，这样的超参数是无效的</p><p>6.参数共享</p><p>输出体数据在深度切片上所有的权重都使用同一个权重向量，那么卷积层在向前传播的过程中每个深度切片都可以看成是神经元的权重对输入数据体做卷积，这也就是为什么把这些3D的权重集合称为滤波器或者卷积核</p><p>7.总结</p><p>卷积层的性质</p><ul><li>（1）输入数据体尺寸是W1×H1×D1</li><li>（2）4个超参数：卷积核数量K，卷积核空间尺寸F，滑动步长S，零填充的数量P</li><li>（3）输出数据体的尺寸为W2×H2×D2，其中$W_2=\frac{W_1-F+2P}{S}+1$,$H_2=\frac{H_1-F+2P}{S}+1$,D2=K</li><li>（4）由于参数共享，每个卷积核包含的权重数目为F×F×D1，卷积层一共有F×F×D1×K个权重和K个偏置</li><li>（5）在输出体数据中，第d个深度切片（空间尺寸是W2×H2），用第d个卷积器和输入数据进行有效卷积运算的结果，再加上第d个偏置</li></ul><p>对于卷积神经网络的一些超参数，常见的设置是F=3，S=1，P=1</p><h4 id="4-2-2-池化层"><a href="#4-2-2-池化层" class="headerlink" title="4.2.2 池化层"></a>4.2.2 池化层</h4><p>通常会在卷积层之间周期性插入一个池化层，作用是逐渐减低数据体的空间尺寸，这样能减少网络中参数的数量，减少计算资源耗费，同时也能有效地控制过拟合</p><p>步骤：设定一个空间窗口，不断滑动窗口，取这些窗口中的最大值作为输出结果</p><p>池化层之所以有效，是因为之前介绍的图片特征具有不变性，也就是通过下采样不会丢失图片拥有的特征</p><p>常用的池化层形式是尺寸为2×2的窗口，滑动步长是2，对图像进行下采样，将其中75%的激活信息都丢掉，选择其中最大的保留，池化层很少引入零填充</p><p>除最大值池化外，还有平均池化，或者L2范数池化，实际证明，最大池化效果最好，平均池化一般放在卷积神经网络最后一层</p><h4 id="4-2-3-全连接层"><a href="#4-2-3-全连接层" class="headerlink" title="4.2.3 全连接层"></a>4.2.3 全连接层</h4><p>全连接层的每个神经元与前一层所有的神经元全部连接，在这个过程中为了防止过拟合会引入<code>Dropout</code>。在进入全连接层之前，使用全局平均池化能够有效地降低过拟合</p><h4 id="4-2-4-卷积神经网络的基本形式"><a href="#4-2-4-卷积神经网络的基本形式" class="headerlink" title="4.2.4 卷积神经网络的基本形式"></a>4.2.4 
卷积神经网络的基本形式</h4><p>卷积神经网络最常见的形式就是将一些卷积层和<code>ReLU</code>层放在一起，有可能在<code>ReLU</code>层前面加上批标准化层，随后紧跟池化层，再不断重复，直到图像被缩小到一个足够小的尺寸，然后将特征图展开，连接几层全连接层，最后输出结果</p><p>1.小滤波器的有效性</p><p>2.网络的尺寸</p><p>经验<br>（1）输入层：一般而言，输入层的大小应该能够被2整除很多次，常用的数字包括32，64，96，224<br>（2）卷积层：卷积层应该尽可能使用小尺寸，比如3×3或5×5，滑动步长取1。7×7通常用在第一个面对原始图像的卷积层上<br>（3）池化层：池化层负责对输入的数据空间维度进行下采样，常用的设置使用2×2的感受野做最大值池化，步长取2<br>（4）零填充：零填充的使用可以让卷积层的输入和输出在空间上的维度保持一致</p><h3 id="4-3-Pytorch卷积模块"><a href="#4-3-Pytorch卷积模块" class="headerlink" title="4.3 Pytorch卷积模块"></a>4.3 Pytorch卷积模块</h3><h4 id="4-3-1-卷积层"><a href="#4-3-1-卷积层" class="headerlink" title="4.3.1 卷积层"></a>4.3.1 卷积层</h4><p><code>nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding,dilation,groups,bias)</code><br>其中</p><ul><li><code>in_channels</code>对应输入数据体的深度</li><li><code>out_channels</code>对应输出数据体的深度</li><li><code>kernel_size</code>表示滤波器（卷积核）的大小，例如：<code>kernel_size=3</code>或<code>kernel_size=(3,2)</code></li><li><code>stride</code>表示滑动步长，默认<code>1</code></li><li><code>padding=0</code>表示四周不进行零填充，<code>padding=1</code>表示四周进行<code>1</code>个像素点的零填充，默认<code>0</code></li><li><code>bias</code>是一个布尔值，默认为<code>True</code>，表示使用偏置</li><li><code>groups</code>表示输出数据体深度上的联系，默认<code>groups=1</code>，即所有的输出和输入都是相关联的，如果<code>groups=2</code>表示输入的深度被分割成两份，输出的深度也被分割成两份，他们之间分别对应起来，所以要求输出和输入都必须要能被<code>groups</code>整除</li><li><code>dilation</code>表示卷积对于输入数据体的空间间隔，默认为<code>1</code></li></ul><h4 id="4-3-2-池化层"><a href="#4-3-2-池化层" class="headerlink" title="4.3.2 池化层"></a>4.3.2 
池化层</h4><p><code>nn.MaxPool2d(kernel_size,stride,padding,dilation,return_indices,ceil_mode)</code><br>其中</p><ul><li><code>kernel_size</code>,<code>stride</code>,<code>padding</code>,<code>dilation</code>和卷积层相同</li><li><code>return_indices</code>表示是否返回最大值所处的下标，默认为<code>False</code></li><li><code>ceil_mode</code>表示使用一些方格代替层结构，默认<code>False</code></li></ul><p><code>nn.AvgPool2d()</code>表示均值池化，里面的参数和MaxPool2d类似，但多一个参数<code>count_include_pad</code>表示计算均值的时候是否包含零填充，默认为<code>True</code></p><p>其他还有<code>nn.LPPool2d()</code>,<code>nn.AdaptiveMaxPool2d()</code></p><p><strong>下面是一个简单的多层卷积神经网络</strong></p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br></pre></td><td 
class="code"><pre><span class="line"><span class="keyword">from</span> torch <span class="keyword">import</span> nn</span><br><span class="line"></span><br><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">SimpleCNN</span>(<span class="params">nn.Module</span>):</span></span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">__init__</span>(<span class="params">self</span>):</span></span><br><span class="line">        <span class="built_in">super</span>(SimpleCNN,self).__init__()</span><br><span class="line">        layer1 = nn.Sequential()</span><br><span class="line">        layer1.add_module(<span class="string">&#x27;conv1&#x27;</span>,nn.Conv2d(<span class="number">3</span>,<span class="number">32</span>,<span class="number">3</span>,<span class="number">1</span>,padding=<span class="number">1</span>))</span><br><span class="line">        layer1.add_module(<span class="string">&#x27;relu1&#x27;</span>,nn.ReLU(<span class="literal">True</span>))</span><br><span class="line">        layer1.add_module(<span class="string">&#x27;pool1&#x27;</span>,nn.MaxPool2d(<span class="number">2</span>,<span class="number">2</span>))</span><br><span class="line">        self.layer1=layer1</span><br><span class="line"></span><br><span class="line">        layer2 = nn.Sequential()</span><br><span class="line">        layer2.add_module(<span class="string">&#x27;conv2&#x27;</span>,nn.Conv2d(<span class="number">32</span>,<span class="number">64</span>,<span class="number">3</span>,<span class="number">1</span>,padding=<span class="number">1</span>))</span><br><span class="line">        layer2.add_module(<span class="string">&#x27;relu2&#x27;</span>,nn.ReLU(<span class="literal">True</span>))</span><br><span class="line">        layer2.add_module(<span class="string">&#x27;pool2&#x27;</span>,nn.MaxPool2d(<span class="number">2</span>,<span class="number">2</span>))</span><br><span 
class="line">        self.layer2=layer2</span><br><span class="line">        </span><br><span class="line">        layer3 = nn.Sequential()</span><br><span class="line">        layer3.add_module(<span class="string">&#x27;conv3&#x27;</span>,nn.Conv2d(<span class="number">64</span>,<span class="number">128</span>,<span class="number">3</span>,<span class="number">1</span>,padding=<span class="number">1</span>))</span><br><span class="line">        layer3.add_module(<span class="string">&#x27;relu3&#x27;</span>,nn.ReLU(<span class="literal">True</span>))</span><br><span class="line">        layer3.add_module(<span class="string">&#x27;pool3&#x27;</span>,nn.MaxPool2d(<span class="number">2</span>,<span class="number">2</span>))</span><br><span class="line">        self.layer3=layer3</span><br><span class="line"></span><br><span class="line">        layer4 = nn.Sequential()</span><br><span class="line">        layer4.add_module(<span class="string">&#x27;fc1&#x27;</span>,nn.Linear(<span class="number">2048</span>,<span class="number">512</span>))</span><br><span class="line">        layer4.add_module(<span class="string">&#x27;fc_relu1&#x27;</span>,nn.ReLU(<span class="literal">True</span>))</span><br><span class="line">        layer4.add_module(<span class="string">&#x27;fc2&#x27;</span>,nn.Linear(<span class="number">512</span>,<span class="number">64</span>))</span><br><span class="line">        layer4.add_module(<span class="string">&#x27;fc_relu2&#x27;</span>,nn.ReLU(<span class="literal">True</span>))</span><br><span class="line">        layer4.add_module(<span class="string">&#x27;fc3&#x27;</span>,nn.Linear(<span class="number">64</span>,<span class="number">10</span>))</span><br><span class="line">        self.layer4=layer4</span><br><span class="line">    </span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">forward</span>(<span class="params">self,x</span>):</span></span><br><span class="line">        
conv1 = self.layer1(x)</span><br><span class="line">        conv2 = self.layer2(conv1)</span><br><span class="line">        conv3 = self.layer3(conv2)</span><br><span class="line">        fc_input = conv3.view(conv3.size(<span class="number">0</span>),-<span class="number">1</span>)</span><br><span class="line">        fc_out = self.layer4(fc_input)</span><br><span class="line">        <span class="keyword">return</span> fc_out</span><br><span class="line"></span><br><span class="line">model = SimpleCNN()</span><br><span class="line"><span class="built_in">print</span>(model)</span><br></pre></td></tr></table></figure><h4 id="4-3-3-提取层结构"><a href="#4-3-3-提取层结构" class="headerlink" title="4.3.3 提取层结构"></a>4.3.3 提取层结构</h4><p>nn.Module具有几个重要属性</p><ul><li><code>children()</code>，会返回下一级模块的迭代器，比如上面这个模型，只会返回在<code>self.layer1</code>,<code>self.layer2</code>,<code>self.layer3</code>以及<code>self.layer4</code>上的迭代器，不会返回他们内部的东西</li><li><code>modules()</code>，会返回模型中所有模块的迭代器，这样就有了一个好处，即它能够访问到最内层，比如<code>self.layer1.conv1</code>这个模块</li><li><code>named_children()</code>和<code>named_modules()</code>不仅会返回模块的迭代器，还会返回网络层的名字</li></ul><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">#提取前面两层</span></span><br><span class="line"><span class="built_in">print</span>(nn.Sequential(*<span class="built_in">list</span>(model.children())[:<span class="number">2</span>]))</span><br></pre></td></tr></table></figure><p><strong>提取所有的卷积层</strong></p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line">conv_model = nn.Sequential()</span><br><span class="line"><span class="keyword">for</span> layer 
<span class="keyword">in</span> model.named_modules():</span><br><span class="line">    <span class="keyword">if</span> <span class="built_in">isinstance</span>(layer[<span class="number">1</span>],nn.Conv2d):</span><br><span class="line">        conv_model.add_module(layer[<span class="number">0</span>].split(<span class="string">&#x27;.&#x27;</span>)[-<span class="number">1</span>],layer[<span class="number">1</span>])</span><br><span class="line"></span><br><span class="line"><span class="built_in">print</span>(conv_model)</span><br></pre></td></tr></table></figure><h4 id="4-3-4-提取参数及自定义初始化"><a href="#4-3-4-提取参数及自定义初始化" class="headerlink" title="4.3.4 提取参数及自定义初始化"></a>4.3.4 提取参数及自定义初始化</h4><p><code>nn.Module</code>关于参数的属性</p><ul><li><code>named_parameters()</code>，给出网络层的名字和参数的迭代器</li><li><code>parameters()</code>，给出一个网络的全部参数的迭代器</li></ul><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">for</span> param <span class="keyword">in</span> model.named_parameters():</span><br><span class="line">    <span class="built_in">print</span>(param[<span class="number">0</span>])</span><br></pre></td></tr></table></figure><p><strong>对权重初始化</strong>，因为权重是Variable，只需要取出data属性就能处理</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">for</span> m <span class="keyword">in</span> model.modules():</span><br><span class="line">    <span class="keyword">if</span> <span class="built_in">isinstance</span>(m,nn.Conv2d):</span><br><span class="line">        nn.init.normal(m.weight.data)</span><br><span 
class="line">        nn.init.xavier_normal(m.weight.data)</span><br><span class="line">        nn.init.kaiming_normal(m.weight.data)<span class="comment">#卷积层参数初始化</span></span><br><span class="line">        m.bias.data.fill_(<span class="number">0</span>)</span><br><span class="line">    <span class="keyword">elif</span> <span class="built_in">isinstance</span>(m,nn.Linear):</span><br><span class="line">        m.weight.data.normal_()<span class="comment">#全连接层参数初始化</span></span><br></pre></td></tr></table></figure><h3 id="4-4-卷积神经网络案例分析"><a href="#4-4-卷积神经网络案例分析" class="headerlink" title="4.4 卷积神经网络案例分析"></a>4.4 卷积神经网络案例分析</h3><h4 id="4-4-1-LeNet"><a href="#4-4-1-LeNet" class="headerlink" title="4.4.1 LeNet"></a>4.4.1 LeNet</h4><p>LeNet是整个卷积神经网络的开山之作，共有7层，其中2层卷积和2层池化层交替出现，最后输出3层全连接层得到整体的效果</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">Lenet</span>(<span class="params">nn.Module</span>):</span></span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">__init__</span>(<span 
class="params">self</span>):</span></span><br><span class="line">        <span class="built_in">super</span>(Lenet,self).__init__()</span><br><span class="line">        layer1 = nn.Sequential()</span><br><span class="line">        layer1.add_module(<span class="string">&#x27;conv1&#x27;</span>,nn.Conv2d(<span class="number">1</span>,<span class="number">6</span>,<span class="number">3</span>,padding=<span class="number">1</span>))</span><br><span class="line">        layer1.add_module(<span class="string">&#x27;pool1&#x27;</span>,nn.MaxPool2d(<span class="number">2</span>,<span class="number">2</span>))</span><br><span class="line">        self.layer1 = layer1</span><br><span class="line">        </span><br><span class="line">        layer2 = nn.Sequential()</span><br><span class="line">        layer2.add_module(<span class="string">&#x27;conv2&#x27;</span>,nn.Conv2d(<span class="number">6</span>,<span class="number">16</span>,<span class="number">5</span>))</span><br><span class="line">        layer2.add_module(<span class="string">&#x27;pool2&#x27;</span>,nn.MaxPool2d(<span class="number">2</span>,<span class="number">2</span>))</span><br><span class="line">        self.layer2 = layer2</span><br><span class="line">        </span><br><span class="line">        layer3 = nn.Sequential()</span><br><span class="line">        layer3.add_module(<span class="string">&#x27;fc1&#x27;</span>,nn.Linear(<span class="number">400</span>,<span class="number">120</span>))</span><br><span class="line">        layer3.add_module(<span class="string">&#x27;fc2&#x27;</span>,nn.Linear(<span class="number">120</span>,<span class="number">84</span>))</span><br><span class="line">        layer3.add_module(<span class="string">&#x27;fc3&#x27;</span>,nn.Linear(<span class="number">84</span>,<span class="number">10</span>))</span><br><span class="line">        self.layer3 = layer3</span><br><span class="line">        </span><br><span class="line">    <span class="function"><span 
class="keyword">def</span> <span class="title">forward</span>(<span class="params">self,x</span>):</span></span><br><span class="line">        x = self.layer1(x)</span><br><span class="line">        x = self.layer2(x)</span><br><span class="line">        x = x.view(x.size(<span class="number">0</span>),-<span class="number">1</span>) <span class="comment"># 将第二次卷积的输出拉伸为一行</span></span><br><span class="line">        x = self.layer3(x)</span><br><span class="line">        <span class="keyword">return</span> x</span><br></pre></td></tr></table></figure><h4 id="4-4-2-AlexNet"><a href="#4-4-2-AlexNet" class="headerlink" title="4.4.2 AlexNet"></a>4.4.2 AlexNet</h4><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">AlexNet</span>(<span class="params">nn.Module</span>):</span></span><br><span 
class="line">    <span class="function"><span class="keyword">def</span> <span class="title">__init__</span>(<span class="params">self,num_classes</span>):</span></span><br><span class="line">        <span class="built_in">super</span>(AlexNet,self).__init__()</span><br><span class="line">        self.features = nn.Sequential(</span><br><span class="line">            nn.Conv2d(<span class="number">3</span>,<span class="number">64</span>,kernel_size=<span class="number">11</span>,stride=<span class="number">4</span>,padding=<span class="number">2</span>),</span><br><span class="line">            nn.ReLU(inplace=<span class="literal">True</span>),</span><br><span class="line">            nn.MaxPool2d(kernel_size=<span class="number">3</span>,stride=<span class="number">2</span>),</span><br><span class="line">            nn.Conv2d(<span class="number">64</span>,<span class="number">192</span>,kernel_size=<span class="number">5</span>,padding=<span class="number">2</span>),</span><br><span class="line">            nn.ReLU(inplace=<span class="literal">True</span>),</span><br><span class="line">            nn.MaxPool2d(kernel_size=<span class="number">3</span>,stride=<span class="number">2</span>),</span><br><span class="line">            nn.Conv2d(<span class="number">192</span>,<span class="number">384</span>,kernel_size=<span class="number">3</span>,padding=<span class="number">1</span>),</span><br><span class="line">            nn.ReLU(inplace=<span class="literal">True</span>),</span><br><span class="line">            nn.Conv2d(<span class="number">384</span>,<span class="number">256</span>,kernel_size=<span class="number">3</span>,padding=<span class="number">1</span>),</span><br><span class="line">            nn.ReLU(inplace=<span class="literal">True</span>),</span><br><span class="line">            nn.Conv2d(<span class="number">256</span>,<span class="number">256</span>,kernel_size=<span class="number">3</span>,padding=<span 
class="number">1</span>),</span><br><span class="line">            nn.ReLU(inplace=<span class="literal">True</span>),</span><br><span class="line">            nn.MaxPool2d(kernel_size=<span class="number">3</span>,stride=<span class="number">2</span>),</span><br><span class="line">        )</span><br><span class="line">        self.classifier = nn.Sequential(</span><br><span class="line">            nn.Dropout(),</span><br><span class="line">            nn.Linear(<span class="number">256</span>*<span class="number">6</span>*<span class="number">6</span>,<span class="number">4096</span>),</span><br><span class="line">            nn.ReLU(inplace=<span class="literal">True</span>),</span><br><span class="line">            nn.Dropout(),</span><br><span class="line">            nn.Linear(<span class="number">4096</span>,<span class="number">4096</span>),</span><br><span class="line">            nn.ReLU(inplace=<span class="literal">True</span>),</span><br><span class="line">            nn.Linear(<span class="number">4096</span>,num_classes),</span><br><span class="line">        )</span><br><span class="line">        </span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">forward</span>(<span class="params">self,x</span>):</span></span><br><span class="line">        x = self.features(x)</span><br><span class="line">        x = x.view(x.size(<span class="number">0</span>),<span class="number">256</span>*<span class="number">6</span>*<span class="number">6</span>)</span><br><span class="line">        x = self.classifier(x)</span><br><span class="line">        <span class="keyword">return</span> x</span><br></pre></td></tr></table></figure><h4 id="4-4-3-VGGNet"><a href="#4-4-3-VGGNet" class="headerlink" title="4.4.3 VGGNet"></a>4.4.3 VGGNet</h4><p>使用了更小的滤波器，同时使用了更深的结构</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span 
class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">VGG</span>(<span class="params">nn.Module</span>):</span></span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">__init__</span>(<span class="params">self,num_classes</span>):</span></span><br><span class="line">        <span 
class="built_in">super</span>(VGG,self).__init__()</span><br><span class="line">        self.features = nn.Sequential(</span><br><span class="line">            nn.Conv2d(<span class="number">3</span>,<span class="number">64</span>,kernel_size=<span class="number">3</span>,padding=<span class="number">1</span>),</span><br><span class="line">            nn.ReLU(<span class="literal">True</span>),</span><br><span class="line">            nn.Conv2d(<span class="number">64</span>,<span class="number">64</span>,kernel_size=<span class="number">3</span>,padding=<span class="number">1</span>),</span><br><span class="line">            nn.ReLU(<span class="literal">True</span>),</span><br><span class="line">            nn.MaxPool2d(kernel_size=<span class="number">2</span>,stride=<span class="number">2</span>),</span><br><span class="line">            nn.Conv2d(<span class="number">64</span>,<span class="number">128</span>,kernel_size=<span class="number">3</span>,padding=<span class="number">1</span>),</span><br><span class="line">            nn.ReLU(<span class="literal">True</span>),</span><br><span class="line">            nn.Conv2d(<span class="number">128</span>,<span class="number">128</span>,kernel_size=<span class="number">3</span>,padding=<span class="number">1</span>),</span><br><span class="line">            nn.ReLU(<span class="literal">True</span>),</span><br><span class="line">            nn.MaxPool2d(kernel_size=<span class="number">2</span>,stride=<span class="number">2</span>),</span><br><span class="line">            nn.Conv2d(<span class="number">128</span>,<span class="number">256</span>,kernel_size=<span class="number">3</span>,padding=<span class="number">1</span>),</span><br><span class="line">            nn.ReLU(<span class="literal">True</span>),</span><br><span class="line">            nn.Conv2d(<span class="number">256</span>,<span class="number">256</span>,kernel_size=<span class="number">3</span>,padding=<span 
class="number">1</span>),</span><br><span class="line">            nn.ReLU(<span class="literal">True</span>),</span><br><span class="line">            nn.Conv2d(<span class="number">256</span>,<span class="number">256</span>,kernel_size=<span class="number">3</span>,padding=<span class="number">1</span>),</span><br><span class="line">            nn.ReLU(<span class="literal">True</span>),</span><br><span class="line">            nn.MaxPool2d(kernel_size=<span class="number">2</span>,stride=<span class="number">2</span>),</span><br><span class="line">            nn.Conv2d(<span class="number">256</span>,<span class="number">512</span>,kernel_size=<span class="number">3</span>,padding=<span class="number">1</span>),</span><br><span class="line">            nn.ReLU(<span class="literal">True</span>),</span><br><span class="line">            nn.Conv2d(<span class="number">512</span>,<span class="number">512</span>,kernel_size=<span class="number">3</span>,padding=<span class="number">1</span>),</span><br><span class="line">            nn.ReLU(<span class="literal">True</span>),</span><br><span class="line">            nn.Conv2d(<span class="number">512</span>,<span class="number">512</span>,kernel_size=<span class="number">3</span>,padding=<span class="number">1</span>),</span><br><span class="line">            nn.ReLU(<span class="literal">True</span>),</span><br><span class="line">            nn.MaxPool2d(kernel_size=<span class="number">2</span>,stride=<span class="number">2</span>),</span><br><span class="line">            nn.Conv2d(<span class="number">512</span>,<span class="number">512</span>,kernel_size=<span class="number">3</span>,padding=<span class="number">1</span>),</span><br><span class="line">            nn.ReLU(<span class="literal">True</span>),</span><br><span class="line">            nn.Conv2d(<span class="number">512</span>,<span class="number">512</span>,kernel_size=<span class="number">3</span>,padding=<span 
class="number">1</span>),</span><br><span class="line">            nn.ReLU(<span class="literal">True</span>),</span><br><span class="line">            nn.Conv2d(<span class="number">512</span>,<span class="number">512</span>,kernel_size=<span class="number">3</span>,padding=<span class="number">1</span>),</span><br><span class="line">            nn.ReLU(<span class="literal">True</span>),</span><br><span class="line">            nn.MaxPool2d(kernel_size=<span class="number">2</span>,stride=<span class="number">2</span>),</span><br><span class="line">        )</span><br><span class="line">        self.classifier = nn.Sequential(</span><br><span class="line">            nn.Linear(<span class="number">512</span>*<span class="number">7</span>*<span class="number">7</span>,<span class="number">4096</span>),</span><br><span class="line">            nn.ReLU(<span class="literal">True</span>),</span><br><span class="line">            nn.Dropout(),</span><br><span class="line">            nn.Linear(<span class="number">4096</span>,<span class="number">4096</span>),</span><br><span class="line">            nn.ReLU(<span class="literal">True</span>),</span><br><span class="line">            nn.Dropout(),</span><br><span class="line">            nn.Linear(<span class="number">4096</span>,num_classes),</span><br><span class="line">        )</span><br><span class="line">        self._initialize_weights()</span><br><span class="line">        </span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">forward</span>(<span class="params">self,x</span>):</span></span><br><span class="line">        x = self.features(x)</span><br><span class="line">        x = x.view(x.size(<span class="number">0</span>),-<span class="number">1</span>)</span><br><span class="line">        x = self.classifier(x)</span><br></pre></td></tr></table></figure><h4 id="4-4-4-GoogleNet"><a href="#4-4-4-GoogleNet" class="headerlink" title="4.4.4 
GoogleNet"></a>4.4.4 GoogleNet</h4><p>也叫InceptionNet，采用了比VGG更深的网络结构，一共22层，但是参数却比AlexNet少了12倍，同时有很高的计算效率，因为它采用了一种很有效的Inception模块，而且没有全连接层。</p><p><strong>Inception模块</strong></p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">BasicConv2d</span>(<span class="params">nn.Module</span>):</span></span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">__init__</span>(<span class="params">self,in_channels,out_channels,**kwargs</span>):</span></span><br><span class="line">        <span class="built_in">super</span>(BasicConv2d,self).__init__()</span><br><span 
class="line">        self.conv = nn.Conv2d(in_channels,out_channels,bias=<span class="literal">False</span>,**kwargs)</span><br><span class="line">        self.bn = nn.BatchNorm2d(out_channels,eps=<span class="number">0.001</span>)</span><br><span class="line">    </span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">forward</span>(<span class="params">self,x</span>):</span></span><br><span class="line">        x = self.conv(x)</span><br><span class="line">        x = self.bn(x)</span><br><span class="line">        <span class="keyword">return</span> F.relu(x,inplace=<span class="literal">True</span>)</span><br><span class="line"></span><br><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">Inception</span>(<span class="params">nn.Module</span>):</span></span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">__init__</span>(<span class="params">self,in_channels,pool_features</span>):</span></span><br><span class="line">        <span class="built_in">super</span>(Inception,self).__init__()</span><br><span class="line">        self.branch1x1 = BasicConv2d(in_channels,<span class="number">64</span>,kernel_size=<span class="number">1</span>)</span><br><span class="line">        self.branch5x5_1 = BasicConv2d(in_channels,<span class="number">48</span>,kernel_size=<span class="number">1</span>)</span><br><span class="line">        self.branch5x5_2 = BasicConv2d(<span class="number">48</span>,<span class="number">64</span>,kernel_size=<span class="number">5</span>,padding=<span class="number">2</span>)</span><br><span class="line">        </span><br><span class="line">        self.branch3x3db1_1 = BasicConv2d(in_channels,<span class="number">64</span>,kernel_size=<span class="number">1</span>)</span><br><span class="line">        self.branch3x3db1_2 = BasicConv2d(<span class="number">64</span>,<span 
class="number">96</span>,kernel_size=<span class="number">3</span>,padding=<span class="number">1</span>)</span><br><span class="line">        self.branch3x3db1_3 = BasicConv2d(<span class="number">96</span>,<span class="number">96</span>,kernel_size=<span class="number">3</span>,padding=<span class="number">1</span>)</span><br><span class="line">        </span><br><span class="line">        self.branch_pool = BasicConv2d(in_channels,pool_features,kernel_size=<span class="number">1</span>)</span><br><span class="line">    </span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">forward</span>(<span class="params">self,x</span>):</span></span><br><span class="line">        branch1x1 = self.branch1x1(x)</span><br><span class="line">        </span><br><span class="line">        branch5x5 = self.branch5x5_1(x)</span><br><span class="line">        branch5x5 = self.branch5x5_2(branch5x5)</span><br><span class="line">        </span><br><span class="line">        branch3x3db1 = self.branch3x3db1_1(x)</span><br><span class="line">        branch3x3db1 = self.branch3x3db1_2(branch3x3db1)</span><br><span class="line">        branch3x3db1 = self.branch3x3db1_3(branch3x3db1)</span><br><span class="line">        </span><br><span class="line">        branch_pool = F.avg_pool2d(x,kernel_size=<span class="number">3</span>,stride=<span class="number">1</span>,padding=<span class="number">1</span>)</span><br><span class="line">        branch_pool = self.branch_pool(branch_pool)</span><br><span class="line">        </span><br><span class="line">        outputs = [branch1x1,branch5x5,branch3x3db1,branch_pool]</span><br><span class="line">        <span class="keyword">return</span> torch.cat(outputs,<span class="number">1</span>) <span class="comment">#按深度拼接</span></span><br></pre></td></tr></table></figure><h4 id="4-4-5-ResNet"><a href="#4-4-5-ResNet" class="headerlink" title="4.4.5 ResNet"></a>4.4.5 
ResNet</h4><p>由微软研究院提出，通过残差模块能够成功地训练高达152层深的神经网络</p><p>ResNet 最初的设计灵感来自这个问题：在不断加深神经网络的时候，会出现一个Degradation，即准确率会先上升然后达到饱和，再持续增加深度则会导致模型准确率下降。</p><p>这并不是过拟合的问题，因为不仅在验证集上误差增加，训练集本身误差也会增加，假设一个比较浅的网络达到了饱和的准确率，那么在后面加上几个恒等映射层，误差不会增加，也就是说更深的模型起码不会使得模型效果下降。</p><p>这里提到的使用恒等映射直接将前一层输出传到后面的思想，就是 ResNet 的灵感来源。假设某个神经网络的输入是x，期望输出是 H(x)，如果直接把输入x传到输出作为初始结果，那么此时需要学习的目标就是 F(x) = H(x) - x<br><img src="" data-lazy-src="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN/post/20201027143439.png"><br>左边是一个普通的网络，右边是一个 ResNet 的残差学习单元，ResNet 相当于将学习目标改变了，不再是学习一个完整的输出 H(x)，而是学习输出和输入的差别 H(x) - x，即残差。</p><p>除了这些比较出名的以外还有很多。并且并不需要重复造轮子，PyTorch内为我们实现了以上网络，都在<code>torchvision.models</code>里面，并且大部分网络都有预训练好的参数</p><h3 id="4-5-再实现MNIST手写数字分类"><a href="#4-5-再实现MNIST手写数字分类" class="headerlink" title="4.5 再实现MNIST手写数字分类"></a>4.5 再实现MNIST手写数字分类</h3><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span
class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br><span class="line">87</span><br><span class="line">88</span><br><span class="line">89</span><br><span class="line">90</span><br><span class="line">91</span><br><span class="line">92</span><br><span class="line">93</span><br><span class="line">94</span><br><span class="line">95</span><br><span class="line">96</span><br><span 
class="line">97</span><br><span class="line">98</span><br><span class="line">99</span><br><span class="line">100</span><br><span class="line">101</span><br><span class="line">102</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">from</span> torch <span class="keyword">import</span> nn,optim</span><br><span class="line"><span class="keyword">from</span> torch.autograd <span class="keyword">import</span> Variable</span><br><span class="line"><span class="keyword">from</span> torch.utils.data <span class="keyword">import</span> DataLoader</span><br><span class="line"><span class="keyword">from</span> torchvision <span class="keyword">import</span> datasets,transforms</span><br><span class="line"></span><br><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">CNN</span>(<span class="params">nn.Module</span>):</span></span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">__init__</span>(<span class="params">self</span>):</span></span><br><span class="line">        <span class="built_in">super</span>(CNN,self).__init__()</span><br><span class="line">        self.layer1 = nn.Sequential(</span><br><span class="line">            nn.Conv2d(<span class="number">1</span>,<span class="number">16</span>,kernel_size=<span class="number">3</span>),</span><br><span class="line">            nn.BatchNorm2d(<span class="number">16</span>),<span class="comment"># 归一化处理，使得数据分布一致，避免梯度消失或梯度爆炸</span></span><br><span class="line">            nn.ReLU(inplace=<span class="literal">True</span>)</span><br><span class="line">        )</span><br><span class="line">        self.layer2 = nn.Sequential(</span><br><span class="line">            nn.Conv2d(<span class="number">16</span>,<span class="number">32</span>,kernel_size=<span class="number">3</span>),</span><br><span class="line">            
nn.BatchNorm2d(<span class="number">32</span>),</span><br><span class="line">            nn.ReLU(inplace=<span class="literal">True</span>),</span><br><span class="line">            nn.MaxPool2d(kernel_size=<span class="number">2</span>,stride=<span class="number">2</span>)</span><br><span class="line">        )</span><br><span class="line">        </span><br><span class="line">        self.layer3 = nn.Sequential(</span><br><span class="line">            nn.Conv2d(<span class="number">32</span>,<span class="number">64</span>,kernel_size=<span class="number">3</span>),</span><br><span class="line">            nn.BatchNorm2d(<span class="number">64</span>),</span><br><span class="line">            nn.ReLU(inplace=<span class="literal">True</span>)</span><br><span class="line">        )</span><br><span class="line">        </span><br><span class="line">        self.layer4 = nn.Sequential(</span><br><span class="line">            nn.Conv2d(<span class="number">64</span>,<span class="number">128</span>,kernel_size=<span class="number">3</span>),</span><br><span class="line">            nn.BatchNorm2d(<span class="number">128</span>),</span><br><span class="line">            nn.ReLU(inplace=<span class="literal">True</span>),</span><br><span class="line">            nn.MaxPool2d(kernel_size=<span class="number">2</span>,stride=<span class="number">2</span>)</span><br><span class="line">        )</span><br><span class="line">        </span><br><span class="line">        self.fc = nn.Sequential(</span><br><span class="line">            nn.Linear(<span class="number">128</span>*<span class="number">4</span>*<span class="number">4</span>,<span class="number">1024</span>),</span><br><span class="line">            nn.ReLU(inplace=<span class="literal">True</span>),</span><br><span class="line">            nn.Linear(<span class="number">1024</span>,<span class="number">128</span>),</span><br><span class="line">            nn.ReLU(inplace=<span 
class="literal">True</span>),</span><br><span class="line">            nn.Linear(<span class="number">128</span>,<span class="number">10</span>)</span><br><span class="line">        )</span><br><span class="line">        </span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">forward</span>(<span class="params">self,x</span>):</span></span><br><span class="line">        x  = self.layer1(x)</span><br><span class="line">        x  = self.layer2(x)</span><br><span class="line">        x  = self.layer3(x)</span><br><span class="line">        x  = self.layer4(x)</span><br><span class="line">        x = x.view(x.size(<span class="number">0</span>),-<span class="number">1</span>)</span><br><span class="line">        x  = self.fc(x)</span><br><span class="line">        <span class="keyword">return</span> x</span><br><span class="line">    </span><br><span class="line">batch_size = <span class="number">64</span></span><br><span class="line">learning_rate = <span class="number">1e-2</span></span><br><span class="line">num_epoch = <span class="number">5</span></span><br><span class="line"></span><br><span class="line">data_tf = transforms.Compose(</span><br><span class="line">    [transforms.ToTensor(),transforms.Normalize([<span class="number">0.5</span>],[<span class="number">0.5</span>])]</span><br><span class="line">)</span><br><span class="line"></span><br><span class="line"><span class="comment"># 获取数据集</span></span><br><span class="line">train_dataset = datasets.MNIST(root=<span class="string">&quot;./data&quot;</span>,train=<span class="literal">True</span>,transform=data_tf,download=<span class="literal">True</span>)</span><br><span class="line">test_dataset = datasets.MNIST(root=<span class="string">&quot;./data&quot;</span>,train=<span class="literal">False</span>,transform=data_tf)</span><br><span class="line"><span class="comment"># 数据迭代器，传入数据集和batch_size，通过shuffle=True来表示是否将数据打乱</span></span><br><span 
class="line">train_loader = DataLoader(train_dataset,batch_size=batch_size,shuffle=<span class="literal">True</span>)</span><br><span class="line">test_loader = DataLoader(test_dataset,batch_size=batch_size,shuffle=<span class="literal">False</span>)</span><br><span class="line"></span><br><span class="line">model = CNN()</span><br><span class="line"><span class="keyword">if</span> torch.cuda.is_available():</span><br><span class="line">    model = model.cuda()</span><br><span class="line"></span><br><span class="line">criterion = nn.CrossEntropyLoss()</span><br><span class="line">optimizer = optim.SGD(model.parameters(),lr=learning_rate)</span><br><span class="line"></span><br><span class="line"><span class="keyword">for</span> epoch <span class="keyword">in</span> <span class="built_in">range</span>(num_epoch):</span><br><span class="line">    eval_loss = <span class="number">0.0</span></span><br><span class="line">    eval_acc = <span class="number">0.0</span></span><br><span class="line">    <span class="built_in">print</span>(<span class="string">&quot;Epoch &#123;&#125;/&#123;&#125;&quot;</span>.<span class="built_in">format</span>(epoch,num_epoch))</span><br><span class="line">    <span class="built_in">print</span>(<span class="string">&quot;-&quot;</span>*<span class="number">20</span>)</span><br><span class="line">    <span class="keyword">for</span> data <span class="keyword">in</span> train_loader:</span><br><span class="line">        img,label=data</span><br><span class="line">        <span class="keyword">if</span> torch.cuda.is_available():</span><br><span class="line">            img = Variable(img).cuda()</span><br><span class="line">            label = Variable(label).cuda()</span><br><span class="line">        <span class="keyword">else</span>:</span><br><span class="line">            img = Variable(img)</span><br><span class="line">            label = Variable(label)</span><br><span class="line">        out=model(img)</span><br><span 
class="line">        loss = criterion(out,label)</span><br><span class="line">        <span class="comment"># backward</span></span><br><span class="line">        optimizer.zero_grad() <span class="comment">#置0</span></span><br><span class="line">        loss.backward() <span class="comment">#求梯度</span></span><br><span class="line">        optimizer.step() <span class="comment">#更新所有的参数，梯度下降</span></span><br><span class="line">        <span class="comment">#acc</span></span><br><span class="line">        eval_loss += loss.data</span><br><span class="line">        _,pred = torch.<span class="built_in">max</span>(out,<span class="number">1</span>)</span><br><span class="line">        eval_acc += (pred == label).<span class="built_in">sum</span>()</span><br><span class="line">        <span class="built_in">print</span>(<span class="string">&#x27;Epoch:&#123;&#125;,Loss: &#123;:.4f&#125;,Acc:&#123;:.4f&#125;%&#x27;</span>.<span class="built_in">format</span>(epoch,eval_loss/(<span class="built_in">len</span>(train_dataset)),<span class="number">100</span>*<span class="built_in">float</span>(eval_acc)/(<span class="built_in">len</span>(train_dataset))))</span><br><span class="line"></span><br><span class="line">        </span><br><span class="line">PATH=<span class="string">&#x27;./minist_net.pth&#x27;</span></span><br><span class="line"><span class="built_in">print</span>(<span class="string">&quot;Train finished!&quot;</span>)</span><br><span class="line">torch.save(model.state_dict(), PATH)</span><br></pre></td></tr></table></figure><h3 id="4-6-图像增强的方法"><a href="#4-6-图像增强的方法" class="headerlink" title="4.6 图像增强的方法"></a>4.6 图像增强的方法</h3><p>torchvision.transforms包括所有图像增强的方法</p><ul><li>Scale，对图片的尺寸进行缩小和放大</li><li>CenterCrop，对图像正中心进行给定大小的裁剪</li><li>RandomCrop，对图片进行给定大小的随机裁剪</li><li>RandomHorizontalFlip，对图像进行概率为0.5的随机水平翻转</li><li>RandomSizedCrop，首先对图片进行随机尺寸的裁剪，然后对裁剪图片进行一个随机比例的缩放，最后将图片变成给定的大小</li><li>Pad，对图片进行边界零填充</li></ul><p>除此之外，还可以使用OpenCV或者PIL等第三方图形库来实现</p><h3 
id="4-7-实现cifar10分类"><a href="#4-7-实现cifar10分类" class="headerlink" title="4.7 实现cifar10分类"></a>4.7 实现cifar10分类</h3><p>cifar10数据集中有60000张图片，每张图片的大小都是32×32的三通道彩色图</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span 
class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br><span class="line">87</span><br><span class="line">88</span><br><span class="line">89</span><br><span class="line">90</span><br><span class="line">91</span><br><span class="line">92</span><br><span class="line">93</span><br><span class="line">94</span><br><span class="line">95</span><br><span class="line">96</span><br><span class="line">97</span><br><span class="line">98</span><br><span class="line">99</span><br><span class="line">100</span><br><span class="line">101</span><br><span class="line">102</span><br><span class="line">103</span><br><span class="line">104</span><br><span class="line">105</span><br><span class="line">106</span><br><span class="line">107</span><br><span class="line">108</span><br><span class="line">109</span><br><span class="line">110</span><br><span class="line">111</span><br><span class="line">112</span><br><span class="line">113</span><br><span class="line">114</span><br><span class="line">115</span><br><span class="line">116</span><br><span 
class="line">117</span><br><span class="line">118</span><br><span class="line">119</span><br><span class="line">120</span><br><span class="line">121</span><br><span class="line">122</span><br><span class="line">123</span><br><span class="line">124</span><br><span class="line">125</span><br><span class="line">126</span><br><span class="line">127</span><br><span class="line">128</span><br><span class="line">129</span><br><span class="line">130</span><br><span class="line">131</span><br><span class="line">132</span><br><span class="line">133</span><br><span class="line">134</span><br><span class="line">135</span><br><span class="line">136</span><br><span class="line">137</span><br><span class="line">138</span><br><span class="line">139</span><br><span class="line">140</span><br><span class="line">141</span><br><span class="line">142</span><br><span class="line">143</span><br><span class="line">144</span><br><span class="line">145</span><br><span class="line">146</span><br><span class="line">147</span><br><span class="line">148</span><br><span class="line">149</span><br><span class="line">150</span><br><span class="line">151</span><br><span class="line">152</span><br><span class="line">153</span><br><span class="line">154</span><br><span class="line">155</span><br><span class="line">156</span><br><span class="line">157</span><br><span class="line">158</span><br><span class="line">159</span><br><span class="line">160</span><br><span class="line">161</span><br><span class="line">162</span><br><span class="line">163</span><br><span class="line">164</span><br><span class="line">165</span><br><span class="line">166</span><br><span class="line">167</span><br><span class="line">168</span><br><span class="line">169</span><br><span class="line">170</span><br><span class="line">171</span><br><span class="line">172</span><br><span class="line">173</span><br><span class="line">174</span><br><span class="line">175</span><br><span class="line">176</span><br><span 
class="line">177</span><br><span class="line">178</span><br><span class="line">179</span><br><span class="line">180</span><br><span class="line">181</span><br><span class="line">182</span><br><span class="line">183</span><br><span class="line">184</span><br><span class="line">185</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">import</span> torchvision</span><br><span class="line"><span class="keyword">import</span> torchvision.transforms <span class="keyword">as</span> transforms</span><br><span class="line"><span class="keyword">import</span> torch.nn <span class="keyword">as</span> nn</span><br><span class="line"><span class="keyword">import</span> torch.nn.functional <span class="keyword">as</span> F</span><br><span class="line"><span class="keyword">import</span> torch.optim <span class="keyword">as</span> optim</span><br><span class="line"><span class="keyword">import</span> os</span><br><span class="line"><span class="keyword">import</span> numpy <span class="keyword">as</span> np</span><br><span class="line"><span class="keyword">from</span> torch.autograd <span class="keyword">import</span> Variable</span><br><span class="line"></span><br><span class="line"><span class="comment">#数据处理</span></span><br><span class="line">train_transform = transforms.Compose([</span><br><span class="line">    transforms.Scale(<span class="number">40</span>),</span><br><span class="line">    transforms.RandomHorizontalFlip(),</span><br><span class="line">    transforms.RandomCrop(<span class="number">32</span>),</span><br><span class="line">    transforms.ToTensor(),</span><br><span class="line">    transforms.Normalize([<span class="number">0.5</span>, <span class="number">0.5</span>, <span class="number">0.5</span>], [<span class="number">0.5</span>, <span class="number">0.5</span>, <span class="number">0.5</span>])</span><br><span class="line">])</span><br><span 
class="line"></span><br><span class="line">test_transform = transforms.Compose([</span><br><span class="line">    transforms.ToTensor(),</span><br><span class="line">    transforms.Normalize([<span class="number">0.5</span>, <span class="number">0.5</span>, <span class="number">0.5</span>], [<span class="number">0.5</span>, <span class="number">0.5</span>, <span class="number">0.5</span>])</span><br><span class="line">])</span><br><span class="line"></span><br><span class="line"><span class="comment">#数据集获取</span></span><br><span class="line">train_set = torchvision.datasets.CIFAR10(root=<span class="string">&#x27;./data&#x27;</span>, train=<span class="literal">True</span>, download=<span class="literal">True</span>, transform=train_transform)</span><br><span class="line">train_data = torch.utils.data.DataLoader(train_set, batch_size=<span class="number">32</span>, shuffle=<span class="literal">True</span>)</span><br><span class="line"></span><br><span class="line">test_set = torchvision.datasets.CIFAR10(root=<span class="string">&#x27;./data&#x27;</span>, train=<span class="literal">False</span>, download=<span class="literal">True</span>, transform=test_transform)</span><br><span class="line">test_data = torch.utils.data.DataLoader(test_set, batch_size=<span class="number">32</span>, shuffle=<span class="literal">False</span>)</span><br><span class="line"></span><br><span class="line">classes = (<span class="string">&#x27;plane&#x27;</span>, <span class="string">&#x27;car&#x27;</span>, <span class="string">&#x27;bird&#x27;</span>, <span class="string">&#x27;cat&#x27;</span>,</span><br><span class="line">           <span class="string">&#x27;deer&#x27;</span>, <span class="string">&#x27;dog&#x27;</span>, <span class="string">&#x27;frog&#x27;</span>, <span class="string">&#x27;horse&#x27;</span>, <span class="string">&#x27;ship&#x27;</span>, <span class="string">&#x27;truck&#x27;</span>)</span><br><span class="line"><span 
class="comment">#3×3卷积层</span></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">conv3x3</span>(<span class="params">in_channel, out_channel, stride=<span class="number">1</span></span>):</span></span><br><span class="line">    <span class="keyword">return</span> nn.Conv2d(in_channel, out_channel, <span class="number">3</span>, stride=stride, padding=<span class="number">1</span>, bias=<span class="literal">False</span>)</span><br><span class="line"></span><br><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">residual_block</span>(<span class="params">nn.Module</span>):</span></span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">__init__</span>(<span class="params">self, in_channel, out_channel, same_shape=<span class="literal">True</span></span>):</span></span><br><span class="line">        <span class="built_in">super</span>(residual_block, self).__init__()</span><br><span class="line">        self.same_shape = same_shape</span><br><span class="line">        stride = <span class="number">1</span> <span class="keyword">if</span> self.same_shape <span class="keyword">else</span> <span class="number">2</span></span><br><span class="line">          </span><br><span class="line">        self.conv1 = conv3x3(in_channel, out_channel, stride=stride)</span><br><span class="line">        self.bn1 = nn.BatchNorm2d(out_channel)</span><br><span class="line">          </span><br><span class="line">        self.conv2 = conv3x3(out_channel, out_channel)</span><br><span class="line">        self.bn2 = nn.BatchNorm2d(out_channel)</span><br><span class="line">        <span class="keyword">if</span> <span class="keyword">not</span> self.same_shape:</span><br><span class="line">            self.conv3 = nn.Conv2d(in_channel, out_channel, <span class="number">1</span>, stride=stride)</span><br><span class="line">        
</span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">forward</span>(<span class="params">self, x</span>):</span></span><br><span class="line">        out = self.conv1(x)</span><br><span class="line">        out = F.relu(self.bn1(out), <span class="literal">True</span>)</span><br><span class="line">        out = self.conv2(out)</span><br><span class="line">        out = F.relu(self.bn2(out), <span class="literal">True</span>)</span><br><span class="line">          </span><br><span class="line">        <span class="keyword">if</span> <span class="keyword">not</span> self.same_shape:</span><br><span class="line">            x = self.conv3(x)</span><br><span class="line">        <span class="keyword">return</span> F.relu(x+out, <span class="literal">True</span>)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">resnet</span>(<span class="params">nn.Module</span>):</span></span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">__init__</span>(<span class="params">self, in_channel, num_classes</span>):</span></span><br><span class="line">        <span class="built_in">super</span>(resnet, self).__init__()</span><br><span class="line">        self.block1 = nn.Conv2d(in_channel, <span class="number">64</span>, <span class="number">7</span>, <span class="number">2</span>,<span class="number">3</span>) <span class="comment"># 32-7+2*3/2+1=16</span></span><br><span class="line">        self.block2 = nn.Sequential(</span><br><span class="line">            nn.MaxPool2d(<span class="number">3</span>, <span class="number">1</span>),</span><br><span class="line">            residual_block(<span class="number">64</span>, <span class="number">64</span>),</span><br><span class="line">            residual_block(<span class="number">64</span>, <span 
class="number">64</span>)</span><br><span class="line">        )</span><br><span class="line">        self.block3 = nn.Sequential(</span><br><span class="line">            residual_block(<span class="number">64</span>, <span class="number">128</span>, <span class="literal">False</span>),</span><br><span class="line">            residual_block(<span class="number">128</span>, <span class="number">128</span>)</span><br><span class="line">        )</span><br><span class="line">        self.block4 = nn.Sequential(</span><br><span class="line">            residual_block(<span class="number">128</span>, <span class="number">256</span>, <span class="literal">False</span>),</span><br><span class="line">            residual_block(<span class="number">256</span>, <span class="number">256</span>)</span><br><span class="line">        )</span><br><span class="line">        self.block5 = nn.Sequential(</span><br><span class="line">            residual_block(<span class="number">256</span>, <span class="number">512</span>, <span class="literal">False</span>),</span><br><span class="line">            residual_block(<span class="number">512</span>, <span class="number">512</span>)</span><br><span class="line">        )</span><br><span class="line">        self.avg_pool = nn.AvgPool2d(<span class="number">2</span>)</span><br><span class="line">        self.classifier = nn.Linear(<span class="number">512</span>, num_classes)</span><br><span class="line">          </span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">forward</span>(<span class="params">self, x</span>):</span></span><br><span class="line">        x = self.block1(x)</span><br><span class="line">        <span class="comment">#print(x.shape)</span></span><br><span class="line">        x = self.block2(x)</span><br><span class="line">        <span class="comment">#print(x.shape)</span></span><br><span class="line">        x = self.block3(x)</span><br><span 
class="line">        <span class="comment">#print(x.shape)</span></span><br><span class="line">        x = self.block4(x)</span><br><span class="line">        <span class="comment">#print(x.shape)</span></span><br><span class="line">        x = self.block5(x)</span><br><span class="line">        <span class="comment">#print(x.shape)</span></span><br><span class="line">        x = self.avg_pool(x)</span><br><span class="line">        x = x.view(x.size(<span class="number">0</span>), -<span class="number">1</span>)</span><br><span class="line">        x = self.classifier(x)</span><br><span class="line">        <span class="keyword">return</span> x</span><br><span class="line"></span><br><span class="line">PATH = <span class="string">&#x27;./cifar_net.pth&#x27;</span></span><br><span class="line">net = resnet(<span class="number">3</span>, <span class="number">10</span>)</span><br><span class="line"><span class="comment">#if os.path.exists(PATH):</span></span><br><span class="line"><span class="comment">#    net.load_state_dict(torch.load(PATH))</span></span><br><span class="line">criterion = nn.CrossEntropyLoss() <span class="comment">#交叉熵</span></span><br><span class="line">optimizer = optim.Adam(net.parameters(), lr=<span class="number">0.01</span>) </span><br><span class="line"></span><br><span class="line"><span class="keyword">from</span> datetime <span class="keyword">import</span> datetime</span><br><span class="line"></span><br><span class="line"><span class="comment">#计算正确率</span></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">get_acc</span>(<span class="params">output, label</span>):</span></span><br><span class="line">    total = output.shape[<span class="number">0</span>]</span><br><span class="line">    _, pred_label = output.<span class="built_in">max</span>(<span class="number">1</span>)</span><br><span class="line">    num_correct = (pred_label == label).<span 
class="built_in">sum</span>().data</span><br><span class="line">    <span class="keyword">return</span> <span class="built_in">float</span>(num_correct) / total</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">train</span>(<span class="params">net, train_data, valid_data, num_epochs, optimizer, criterion</span>):</span></span><br><span class="line">    <span class="keyword">if</span> torch.cuda.is_available():</span><br><span class="line">        net = net.cuda()</span><br><span class="line">    <span class="comment">#计时</span></span><br><span class="line">    prev_time = datetime.now()</span><br><span class="line">    <span class="keyword">for</span> epoch <span class="keyword">in</span> <span class="built_in">range</span>(num_epochs):</span><br><span class="line">        <span class="built_in">print</span>(<span class="string">&quot;*&quot;</span>*<span class="number">10</span>)</span><br><span class="line">        train_loss = <span class="number">0.0</span></span><br><span class="line">        train_acc = <span class="number">0.0</span></span><br><span class="line">        net = net.train() <span class="comment">#训练模式</span></span><br><span class="line">        <span class="keyword">for</span> data <span class="keyword">in</span> train_data:</span><br><span class="line">            im,label = data</span><br><span class="line">            <span class="keyword">if</span> torch.cuda.is_available():</span><br><span class="line">                im = Variable(im.cuda())</span><br><span class="line">                label = Variable(label.cuda())</span><br><span class="line">            <span class="keyword">else</span>:</span><br><span class="line">                im = Variable(im)</span><br><span class="line">                label = Variable(label)</span><br><span class="line">            <span class="comment">#forward</span></span><br><span class="line">            output = 
net(im)</span><br><span class="line">            loss = criterion(output, label)</span><br><span class="line">            <span class="comment">#forward</span></span><br><span class="line">            optimizer.zero_grad()</span><br><span class="line">            loss.backward()</span><br><span class="line">            optimizer.step()</span><br><span class="line">               </span><br><span class="line">            train_loss += loss.data</span><br><span class="line">            train_acc += get_acc(output, label)</span><br><span class="line">        <span class="comment">#计时</span></span><br><span class="line">        cur_time = datetime.now()</span><br><span class="line">        h, remainder = <span class="built_in">divmod</span>((cur_time-prev_time).seconds, <span class="number">3600</span>)</span><br><span class="line">        m, s = <span class="built_in">divmod</span>(remainder, <span class="number">60</span>)</span><br><span class="line">        time_str = <span class="string">&quot;Time %02d:%02d:%02d&quot;</span> % (h, m, s)</span><br><span class="line">        <span class="comment">#测试</span></span><br><span class="line">        <span class="keyword">if</span> valid_data <span class="keyword">is</span> <span class="keyword">not</span> <span class="literal">None</span>:</span><br><span class="line">            valid_loss = <span class="number">0.0</span></span><br><span class="line">            valid_acc = <span class="number">0.0</span></span><br><span class="line">            net = net.<span class="built_in">eval</span>() <span class="comment"># 切换测试模式</span></span><br><span class="line">            <span class="keyword">for</span> data <span class="keyword">in</span> valid_data:</span><br><span class="line">                im, label = data</span><br><span class="line">                <span class="keyword">if</span> torch.cuda.is_available():</span><br><span class="line">                    im = Variable(im.cuda())</span><br><span class="line">      
              label = Variable(label.cuda())</span><br><span class="line">                <span class="keyword">else</span>:</span><br><span class="line">                    im = Variable(im)</span><br><span class="line">                    label = Variable(label)</span><br><span class="line">                output = net(im)</span><br><span class="line">                loss = criterion(output, label)</span><br><span class="line">                valid_loss += loss.item()</span><br><span class="line">                valid_acc += get_acc(output, label)</span><br><span class="line">            epoch_str = (</span><br><span class="line">                <span class="string">&quot;Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, &quot;</span></span><br><span class="line">                % (epoch, train_loss / <span class="built_in">len</span>(train_data),</span><br><span class="line">                   train_acc / <span class="built_in">len</span>(train_data), valid_loss / <span class="built_in">len</span>(valid_data),</span><br><span class="line">                   valid_acc / <span class="built_in">len</span>(valid_data)))</span><br><span class="line">        <span class="keyword">else</span>:</span><br><span class="line">            epoch_str = (<span class="string">&quot;Epoch %d. 
Train Loss: %f, Train Acc: %f, &quot;</span> %</span><br><span class="line">                         (epoch, train_loss / <span class="built_in">len</span>(train_data),</span><br><span class="line">                          train_acc / <span class="built_in">len</span>(train_data)))</span><br><span class="line">               </span><br><span class="line">        prev_time = cur_time</span><br><span class="line">        <span class="built_in">print</span>(epoch_str + time_str)</span><br><span class="line"></span><br><span class="line">train(net, train_data, test_data, <span class="number">10</span>, optimizer, criterion) </span><br><span class="line"><span class="built_in">print</span>(<span class="string">&#x27;Finished Training&#x27;</span>)</span><br><span class="line"></span><br><span class="line">torch.save(net.state_dict(), PATH)</span><br></pre></td></tr></table></figure><p><strong>测试</strong></p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span 
class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> matplotlib.pyplot <span class="keyword">as</span> plt</span><br><span class="line"><span class="keyword">import</span> numpy <span class="keyword">as</span> np</span><br><span class="line"></span><br><span class="line">test_set = torchvision.datasets.CIFAR10(root=<span class="string">&#x27;./data&#x27;</span>, train=<span class="literal">False</span>, download=<span class="literal">True</span>, transform=test_transform)</span><br><span class="line">test_data = torch.utils.data.DataLoader(test_set, batch_size=<span class="number">4</span>, shuffle=<span class="literal">False</span>)</span><br><span class="line"></span><br><span class="line"><span class="comment"># 输出图像的函数</span></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">imshow</span>(<span class="params">img</span>):</span></span><br><span class="line">    img = img / <span class="number">2</span> + <span class="number">0.5</span>     <span class="comment"># unnormalize</span></span><br><span class="line">    npimg = img.numpy()</span><br><span class="line">    plt.imshow(np.transpose(npimg, (<span class="number">1</span>, <span class="number">2</span>, <span class="number">0</span>)))</span><br><span class="line">    plt.show()</span><br><span class="line"></span><br><span class="line">dataiter = <span class="built_in">iter</span>(test_data)</span><br><span 
class="line">images, labels = dataiter.<span class="built_in">next</span>()</span><br><span class="line"></span><br><span class="line"><span class="comment">#print(images.shape)</span></span><br><span class="line"><span class="comment"># 输出图片</span></span><br><span class="line">imshow(torchvision.utils.make_grid(images))</span><br><span class="line"><span class="built_in">print</span>(<span class="string">&#x27;GroundTruth: &#x27;</span>, <span class="string">&#x27; &#x27;</span>.join(<span class="string">&#x27;%5s&#x27;</span> % classes[labels[j]] <span class="keyword">for</span> j <span class="keyword">in</span> <span class="built_in">range</span>(<span class="number">4</span>)))</span><br><span class="line"></span><br><span class="line">PATH = <span class="string">&#x27;./cifar_net.pth&#x27;</span></span><br><span class="line">net.load_state_dict(torch.load(PATH))</span><br><span class="line"></span><br><span class="line">outputs = net(images)</span><br><span class="line"></span><br><span class="line">_, predicted = torch.<span class="built_in">max</span>(outputs, <span class="number">1</span>)</span><br><span class="line"></span><br><span class="line"><span class="built_in">print</span>(<span class="string">&#x27;Predicted: &#x27;</span>, <span class="string">&#x27; &#x27;</span>.join(<span class="string">&#x27;%5s&#x27;</span> % classes[predicted[j]] <span class="keyword">for</span> j <span class="keyword">in</span> <span class="built_in">range</span>(<span class="number">4</span>)))</span><br><span class="line"></span><br><span class="line">class_correct = <span class="built_in">list</span>(<span class="number">0.</span> <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(<span class="number">10</span>))</span><br><span class="line">class_total = <span class="built_in">list</span>(<span class="number">0.</span> <span class="keyword">for</span> i <span class="keyword">in</span> <span 
class="built_in">range</span>(<span class="number">10</span>))</span><br><span class="line"><span class="keyword">with</span> torch.no_grad():</span><br><span class="line">    <span class="keyword">for</span> data <span class="keyword">in</span> test_data:</span><br><span class="line">        images, labels = data</span><br><span class="line">        outputs = net(images)</span><br><span class="line">        _, predicted = torch.<span class="built_in">max</span>(outputs, <span class="number">1</span>)</span><br><span class="line">        c = (predicted == labels).squeeze()</span><br><span class="line">        <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(<span class="number">4</span>):</span><br><span class="line">            label = labels[i]</span><br><span class="line">            class_correct[label] += c[i].item()</span><br><span class="line">            class_total[label] += <span class="number">1</span></span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(<span class="number">10</span>):</span><br><span class="line">    <span class="built_in">print</span>(<span class="string">&#x27;Accuracy of %5s : %2d %%&#x27;</span> % (</span><br><span class="line">        classes[i], <span class="number">100</span> * class_correct[i] / class_total[i]))</span><br></pre></td></tr></table></figure><h2 id="第五章-循环神经网络"><a href="#第五章-循环神经网络" class="headerlink" title="第五章 循环神经网络"></a>第五章 循环神经网络</h2><p>RNN，在序列问题和自然语言处理等领域取得很大的成功</p><h3 id="5-1-循环神经网络"><a href="#5-1-循环神经网络" class="headerlink" title="5.1 循环神经网络"></a>5.1 循环神经网络</h3><p>卷积神经网络相当于人类的视觉，但是它没有记忆能力，所以它只能处理一种特定的视觉任务，没办法根据以前的记忆来处理新的任务。</p><p>循环神经网络的提出便是基于记忆模型的想法，期望网络能够记住前面出现的特征，并依据特征推断后面的结果，而且整体的网络结构不断循环，因而得名循环神经网络</p><p>比如：某一个单词的意思会因为上文提到的内容不同而有不同的含义，RNN可以很好的解决这类问题</p><h4 id="5-1-1-问题介绍"><a href="#5-1-1-问题介绍" class="headerlink" 
title="5.1.1 问题介绍"></a>5.1.1 问题介绍</h4><p>对于下面两句话</p><ul><li>arrive beijing on November 2nd</li><li>leave beijing on November 2nd</li></ul><p>第一句话表达到达，第二句话表示离开，如果网络能构记忆“beijing”前面的词，就会预测出不同的结果。</p><h4 id="5-1-2-循环神经网络的基本结构"><a href="#5-1-2-循环神经网络的基本结构" class="headerlink" title="5.1.2 循环神经网络的基本结构"></a>5.1.2 循环神经网络的基本结构</h4><p>将网络的输出保存在一个记忆单元中，这个记忆单元和下一次的输入一起进入神经网络中。因此，输入序列（sequences）的顺序改变，会改变网络的输出结果。</p><p><img src="" data-lazy-src="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN/post/v2-206db7ba9d32a80ff56b6cc988a62440_r.jpg"><br><img src="" data-lazy-src="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN/post/v2-b0175ebd3419f9a11a3d0d8b00e28675_r.jpg"></p><p>这个网络在t时刻接收到输入$x_t$之后，隐藏层的值是$S_t$，输出值是$O_t$。关键一点是，$S_t$的值不仅仅取决于$x_t$，还取决于$S_{t-1}$。我们可以用下面的公式来表示循环神经网络的计算方法：</p><p>$$O _ t = g(VS_t)$$<br>$$S _ t = f(UX_t+WS_{t-1})$$</p><h4 id="5-1-3-存在的问题"><a href="#5-1-3-存在的问题" class="headerlink" title="5.1.3 存在的问题"></a>5.1.3 存在的问题</h4><p>循环神经网络具有很好的记忆特性，能够将记忆内容应用到当前情景下，但是记忆最大的问题在于遗忘性</p><h3 id="5-2-循环神经网络的变式：LSTM和GRU"><a href="#5-2-循环神经网络的变式：LSTM和GRU" class="headerlink" title="5.2 循环神经网络的变式：LSTM和GRU"></a>5.2 循环神经网络的变式：LSTM和GRU</h3><h4 id="5-2-1-LSTM"><a href="#5-2-1-LSTM" class="headerlink" title="5.2.1 LSTM"></a>5.2.1 LSTM</h4><p>LSTM是Long Short Term Memory Networks的缩写，是一种链式循环的网络结构，在网络内部有着更复杂的结构，主要为了解决长序列训练过程中的梯度下降和梯度爆炸问题。</p><p>LSTM由三个门来控制，分别是输入门，遗忘门和输出门。顾名思义，输入门控制着网络的输入，遗忘门控制着记忆单元，输出门控制着网络的输出。这其中最重要的就是遗忘门，遗忘门的作用是决定之前的哪些记忆及那个被保留，那些记忆将被去掉，正是由于遗忘门的作用，使得LSTM具有了长时记忆的功能</p><h4 id="5-2-2-GRU"><a href="#5-2-2-GRU" class="headerlink" title="5.2.2 GRU"></a>5.2.2 GRU</h4><p>GRU是Gated Recurrent Unit的缩写，由Cho于2014年提出，GRU和LSTM最大的不同在于GRU将遗忘门和输入门合成了一个“更新门”，同时网络不再额外给出记忆状态Ct，而是将输出结果ht作为记忆状态不断向后循环传递，网络的输出和出入变得简单</p><h4 id="5-2-3-收敛性问题"><a href="#5-2-3-收敛性问题" class="headerlink" title="5.2.3 收敛性问题"></a>5.2.3 收敛性问题</h4><p>如果写了一个简单的LSTM网络去训练数据，会发现loss并不会按照想象的方式下降，而是在乱跳，这是因为RNN的误差曲面粗糙不平导致的，而解决方法是梯度裁剪（gradient clipping）</p><h3 id="5-3-循环神经网络的PyTorch实现"><a href="#5-3-循环神经网络的PyTorch实现" class="headerlink" 
title="5.3 循环神经网络的PyTorch实现"></a>5.3 循环神经网络的PyTorch实现</h3><h4 id="5-3-1-PyTorch的循环网络模块"><a href="#5-3-1-PyTorch的循环网络模块" class="headerlink" title="5.3.1 PyTorch的循环网络模块"></a>5.3.1 PyTorch的循环网络模块</h4><p><strong>1.标准RNN</strong></p><p><code>nn.RNN()</code><br><strong>参数</strong></p><ul><li><code>input_size</code>表示输入$x_t$的维度</li><li><code>hidden_size</code>表示输出$h_t$的维度</li><li><code>num_layers</code>表示网络层数，默认为1层</li><li><code>nonlinearity</code>表示非线性激活函数，默认为tanh，可选relu</li><li><code>bias</code>表示是否使用偏置，默认为True</li><li><code>batch_first</code>决定网络输入的维度顺序，默认输入顺序（seq,batch,feature），如果设置为True，则顺序为（batch，seq，feature）</li><li><code>dropout</code>，接受一个0到1的数值，并在除最后一层的其他输出层加上dropout层</li><li><code>bidirectional</code>默认是False，如果设置为True，就是双向循环神经网络的结构</li></ul><p><strong>网络接受的输入</strong></p><ul><li>序列输入$x_t$：$x_t$的维度是（seq，batch，feature），分别表示序列长度，批量和输入的特征维度</li><li>记忆输入$h_0$：$h_0$也叫隐藏状态，它的维度是（layers×direction，batch，hidden），分别表示层数乘方向（单向1，双向2），批量和输出的维度</li></ul><p><strong>网络的输出</strong></p><ul><li>output，表示网络实际的输出，维度是（seq，batch，hidden×direction），分别表示序列长度，批量和输出维度乘方向</li><li>$h_n$表示记忆单元，维度是（layer×direction，batch，hidden）分别表示层数乘方向，批量，输出维度</li></ul><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line">basic_rnn = nn.RNN(input_size=<span class="number">20</span>,hidden_size=<span class="number">50</span>,num_layers=<span class="number">2</span>)</span><br><span class="line"></span><br><span class="line">toy_input = Variable(torch.randn(<span class="number">100</span>,<span class="number">32</span>,<span class="number">20</span>)) <span class="comment"># seq,batch,input_size</span></span><br><span class="line">h_0 = Variable(torch.rand(<span class="number">2</span>,<span class="number">32</span>,<span class="number">50</span>)) <span 
class="comment"># layer * direction,batch,hidden_size</span></span><br><span class="line"></span><br><span class="line">toy_output,h_n = basic_rnn(toy_input,h_0)</span><br></pre></td></tr></table></figure><p><strong>2.LSTM</strong></p><p><code>nn.LSTM()</code><br>参数和标准RNN一样</p><p>LSTM与RNN不同的地方：</p><ul><li>LSTM的参数比标准RNN多，是标准RNN维度的4倍，但是访问的方式仍然是相同的</li><li>LSTM的输入还多了一个$C_0$，它们合在一起称为网络的隐藏状态，即（layer×direction，batch，hidden），当然输出也会有$h_0$,$C_0$</li></ul><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">lstm = nn.LSTM(input_size=<span class="number">20</span>,hidden_size=<span class="number">50</span>,num_layers=<span class="number">2</span>)</span><br><span class="line"></span><br><span class="line">lstm_input = Variable(torch.randn(<span class="number">10</span>, <span class="number">3</span>, <span class="number">20</span>))</span><br><span class="line">out, (h, c) = lstm(lstm_input)</span><br></pre></td></tr></table></figure><p><strong>3.GRU</strong></p><p>GRU本质上和LSTM一样</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">gru_seq = nn.GRU(<span class="number">10</span>, <span class="number">20</span>)</span><br><span class="line">gru_input = Variable(torch.randn(<span class="number">3</span>, <span class="number">32</span>, <span class="number">10</span>))</span><br><span class="line"></span><br><span class="line">out, h = gru_seq(gru_input)</span><br></pre></td></tr></table></figure><p>它和LSTM不同的地方：</p><ul><li>参数是标准RNN的三倍</li><li>网络的隐藏状态只有h0</li></ul><h4 id="5-3-2-实例介绍"><a href="#5-3-2-实例介绍" class="headerlink" title="5.3.2 实例介绍"></a>5.3.2 实例介绍</h4><p>序列预测</p><figure class="highlight 
python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span 
class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br><span class="line">87</span><br><span class="line">88</span><br><span class="line">89</span><br><span class="line">90</span><br><span class="line">91</span><br><span class="line">92</span><br><span class="line">93</span><br><span class="line">94</span><br><span class="line">95</span><br><span class="line">96</span><br><span class="line">97</span><br><span class="line">98</span><br><span class="line">99</span><br><span class="line">100</span><br><span class="line">101</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">import</span> torch.nn <span class="keyword">as</span> nn</span><br><span class="line"><span class="keyword">import</span> numpy <span class="keyword">as</span> np</span><br><span class="line"><span class="keyword">import</span> pandas <span class="keyword">as</span> pd</span><br><span class="line"><span class="keyword">import</span> matplotlib.pyplot <span class="keyword">as</span> plt</span><br><span class="line"><span class="keyword">from</span> torch.autograd <span class="keyword">import</span> Variable</span><br><span class="line">%matplotlib inline</span><br><span 
class="line"></span><br><span class="line"><span class="comment">#希望通过前两个月的流量来预测当月的流量</span></span><br><span class="line"><span class="comment">#将前两个月的流量当做输入，当月的流量当做输出</span></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">create_dataset</span>(<span class="params">dataset,look_back=<span class="number">2</span></span>):</span></span><br><span class="line">    dataX,dataY = [],[]</span><br><span class="line">    <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(<span class="built_in">len</span>(dataset)-look_back):</span><br><span class="line">        a = dataset[i:(i+look_back)]</span><br><span class="line">        dataX.append(a)</span><br><span class="line">        dataY.append(dataset[i+look_back])</span><br><span class="line">    <span class="keyword">return</span> np.array(dataX),np.array(dataY)</span><br><span class="line"></span><br><span class="line"><span class="comment"># 定义模型</span></span><br><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">lstm_reg</span>(<span class="params">nn.Module</span>):</span></span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">__init__</span>(<span class="params">self, input_size, hidden_size, output_size=<span class="number">1</span>, num_layers=<span class="number">2</span></span>):</span></span><br><span class="line">        <span class="built_in">super</span>(lstm_reg, self).__init__()</span><br><span class="line">        </span><br><span class="line">        self.rnn = nn.LSTM(input_size, hidden_size, num_layers) <span class="comment"># rnn</span></span><br><span class="line">        self.reg = nn.Linear(hidden_size, output_size) <span class="comment"># 回归</span></span><br><span class="line">        </span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span 
class="title">forward</span>(<span class="params">self, x</span>):</span></span><br><span class="line">        x, _ = self.rnn(x) <span class="comment"># (seq, batch, hidden)</span></span><br><span class="line">        s, b, h = x.shape</span><br><span class="line">        x = x.view(s*b, h) <span class="comment"># 转换成线性层的输入格式</span></span><br><span class="line">        x = self.reg(x)</span><br><span class="line">        x = x.view(s, b, -<span class="number">1</span>)</span><br><span class="line">        <span class="keyword">return</span> x</span><br><span class="line"></span><br><span class="line"><span class="comment">#读取数据</span></span><br><span class="line">data_csv = pd.read_csv(<span class="string">&#x27;./data.csv&#x27;</span>, usecols=[<span class="number">1</span>])</span><br><span class="line"></span><br><span class="line"><span class="comment"># 预处理，将数据中na的数据去掉，然后将数据标准化到0~1之间</span></span><br><span class="line">data_csv = data_csv.dropna()</span><br><span class="line">dataset = data_csv.values</span><br><span class="line">dataset = dataset.astype(<span class="string">&#x27;float32&#x27;</span>)</span><br><span class="line">max_value = np.<span class="built_in">max</span>(dataset)</span><br><span class="line">min_value = np.<span class="built_in">min</span>(dataset)</span><br><span class="line">scalar = max_value - min_value</span><br><span class="line">dataset = <span class="built_in">list</span>(<span class="built_in">map</span>(<span class="keyword">lambda</span> x: x / scalar, dataset))</span><br><span class="line"></span><br><span class="line"><span class="comment"># 创建好输入输出</span></span><br><span class="line">data_X, data_Y = create_dataset(dataset)</span><br><span class="line"></span><br><span class="line"><span class="comment"># 划分训练集和测试集，70% 作为训练集</span></span><br><span class="line">train_size = <span class="built_in">int</span>(<span class="built_in">len</span>(data_X) * <span class="number">0.7</span>)</span><br><span class="line">test_size 
= <span class="built_in">len</span>(data_X) - train_size</span><br><span class="line">train_X = data_X[:train_size]</span><br><span class="line">train_Y = data_Y[:train_size]</span><br><span class="line">test_X = data_X[train_size:]</span><br><span class="line">test_Y = data_Y[train_size:]</span><br><span class="line"></span><br><span class="line"><span class="comment">#将数据改变一下形状 (seq, batch, feature)</span></span><br><span class="line"><span class="comment">#只有一个序列，所以 batch 是 1</span></span><br><span class="line"><span class="comment">#输入的feature是希望依据的几个月份，这里定的是两个月份，feature=2.</span></span><br><span class="line">train_X = train_X.reshape(-<span class="number">1</span>, <span class="number">1</span>, <span class="number">2</span>)</span><br><span class="line">train_Y = train_Y.reshape(-<span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>)</span><br><span class="line">test_X = test_X.reshape(-<span class="number">1</span>, <span class="number">1</span>, <span class="number">2</span>)</span><br><span class="line"></span><br><span class="line">train_x = torch.from_numpy(train_X)</span><br><span class="line">train_y = torch.from_numpy(train_Y)</span><br><span class="line">test_x = torch.from_numpy(test_X)</span><br><span class="line"></span><br><span class="line"><span class="comment"># 定义损失和优化</span></span><br><span class="line">net = lstm_reg(<span class="number">2</span>, <span class="number">4</span>)</span><br><span class="line">criterion = nn.MSELoss()</span><br><span class="line">optimizer = torch.optim.Adam(net.parameters(), lr=<span class="number">1e-2</span>)</span><br><span class="line"></span><br><span class="line"><span class="comment"># 开始训练</span></span><br><span class="line"><span class="keyword">for</span> e <span class="keyword">in</span> <span class="built_in">range</span>(<span class="number">1000</span>):</span><br><span class="line">    var_x = Variable(train_x)</span><br><span class="line">    var_y = 
Variable(train_y)</span><br><span class="line">    <span class="comment"># 前向传播</span></span><br><span class="line">    out = net(var_x)</span><br><span class="line">    loss = criterion(out, var_y)</span><br><span class="line">    <span class="comment"># 反向传播</span></span><br><span class="line">    optimizer.zero_grad()</span><br><span class="line">    loss.backward()</span><br><span class="line">    optimizer.step()</span><br><span class="line">    <span class="keyword">if</span> (e + <span class="number">1</span>) % <span class="number">100</span> == <span class="number">0</span>: <span class="comment"># 每 100 次输出结果</span></span><br><span class="line">        <span class="built_in">print</span>(<span class="string">&#x27;Epoch: &#123;&#125;, Loss: &#123;:.5f&#125;&#x27;</span>.<span class="built_in">format</span>(e + <span class="number">1</span>, loss.data))</span><br><span class="line">        </span><br><span class="line"><span class="comment">#测试</span></span><br><span class="line">net = net.<span class="built_in">eval</span>() <span class="comment"># 转换成测试模式</span></span><br><span class="line">data_X = data_X.reshape(-<span class="number">1</span>, <span class="number">1</span>, <span class="number">2</span>)</span><br><span class="line">data_X = torch.from_numpy(data_X)</span><br><span class="line">var_data = Variable(data_X)</span><br><span class="line">pred_test = net(var_data) <span class="comment"># 测试集的预测结果</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># 改变输出的格式</span></span><br><span class="line">pred_test = pred_test.view(-<span class="number">1</span>).data.numpy()</span><br><span class="line"></span><br><span class="line"><span class="comment"># 画出实际结果和预测的结果</span></span><br><span class="line">plt.plot(pred_test, <span class="string">&#x27;r&#x27;</span>, label=<span class="string">&#x27;prediction&#x27;</span>)</span><br><span class="line">plt.plot(dataset, <span class="string">&#x27;b&#x27;</span>, 
label=<span class="string">&#x27;real&#x27;</span>)</span><br><span class="line">plt.legend(loc=<span class="string">&#x27;best&#x27;</span>)</span><br></pre></td></tr></table></figure><p><img src="" data-lazy-src="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN2/post/20201102181225.png"></p><h3 id="5-4-自然语言处理的应用"><a href="#5-4-自然语言处理的应用" class="headerlink" title="5.4 自然语言处理的应用"></a>5.4 自然语言处理的应用</h3><h4 id="5-4-1-词嵌入"><a href="#5-4-1-词嵌入" class="headerlink" title="5.4.1 词嵌入"></a>5.4.1 词嵌入</h4><p>词嵌入（word embedding），也称为词向量，即对于每个词，可以使用一个高维向量去表示它</p><p>例如：</p><ul><li>(1)The cat likes playing ball</li><li>(2)The kitty likes playing wool</li><li>(3)The dog likes playing ball</li><li>(4)The boy doesn’t like playing ball</li></ul><p>对于这四句话里的四个词，cat，kitty，dog，boy，如果用one-hot编码，那么cat可以是（1，0，0，0），kitty可以是（0，1，0，0），但是cat和kitty都是小猫，所以这两个词实际语义是接近的，但是one-hot不能体现这个特点，于是可以用词嵌入的方式表示这四个词。</p><p>假设使用一个二维向量（a，b）来表示一个词，其中a代表是否喜欢玩球，b代表是否喜欢玩毛线，且数值越大代表越喜欢，那么对于cat可以表示（-1，4），对于kitty可以表示为（-2，5），对于dog可以表示为（3，-2），对于boy可以表示为（-2，-3）</p><p><img src="" data-lazy-src="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN2/post/20201102190952.png"></p><p>可以发现kitty和cat的夹角更小，所以它们更加相似</p><h4 id="5-4-2-词嵌入的PyTorch实现"><a href="#5-4-2-词嵌入的PyTorch实现" class="headerlink" title="5.4.2 词嵌入的PyTorch实现"></a>5.4.2 词嵌入的PyTorch实现</h4><p>PyTorch中的词嵌入是通过函数<code>nn.Embedding(m,n)</code>来实现的，其中m表示所有的单词数目，n表示词嵌入的维度</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line">word_to_ix = &#123;<span class="string">&#x27;hello&#x27;</span>:<span class="number">0</span>,<span class="string">&#x27;world&#x27;</span>:<span class="number">1</span>&#125;</span><br><span class="line">embeds = nn.Embedding(<span class="number">2</span>,<span class="number">5</span>)</span><br><span 
class="line">hello_idx = torch.LongTensor([word_to_ix[<span class="string">&#x27;hello&#x27;</span>]])</span><br><span class="line">hello_idx = Variable(hello_idx)</span><br><span class="line">hello_embed = embeds(hello_idx)</span><br><span class="line"><span class="built_in">print</span>(hello_embed)</span><br></pre></td></tr></table></figure><h4 id="5-4-3-N-Gram模型"><a href="#5-4-3-N-Gram模型" class="headerlink" title="5.4.3 N Gram模型"></a>5.4.3 N Gram模型</h4><p>对于一句话，单词的排列顺序是非常重要的，所以我们能否由前面的几个词来预测后面的几个单词呢，比如 ‘I lived in France for 10 years, I can speak _ ‘ 这句话中，我们能够预测出最后一个词是 French。</p><p>对于一句话T，它由w1，w2,…,wn这n个词构成，可以得到下面的公式<br>$$<br>P(T) = P(w_1)P(w_2 | w_1)P(w_3 |w_2 w_1) \cdots P(w_n |w_{n-1} w_{n-2}\cdots w_2w_1)<br>$$<br>但是该模型存在如参数空间过大等缺陷，因此引入了马尔科夫假设，也就是说这个单词只与前面的几个词有关系。</p><p>对于这个条件概率，传统的方式是统计语料中每个单词出现的频率，据此来估计这个条件概率，这里使用词嵌入的办法，直接在语料中计算这个条件概率，然后最大化条件概率从而优化词向量，据此进行预测</p><h4 id="5-4-4-单词预测的PyTorch实现"><a href="#5-4-4-单词预测的PyTorch实现" class="headerlink" title="5.4.4 单词预测的PyTorch实现"></a>5.4.4 单词预测的PyTorch实现</h4><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span 
class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">from</span> torch <span class="keyword">import</span> nn</span><br><span class="line"><span class="keyword">import</span> torch.nn.functional <span 
class="keyword">as</span> F</span><br><span class="line"><span class="keyword">from</span> torch.autograd <span class="keyword">import</span> Variable</span><br><span class="line"></span><br><span class="line">CONTEXT_SIZE = <span class="number">2</span> <span class="comment"># 依据的单词数</span></span><br><span class="line">EMBEDDING_DIM = <span class="number">10</span> <span class="comment"># 词向量的维度</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># 定义模型</span></span><br><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">n_gram</span>(<span class="params">nn.Module</span>):</span></span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">__init__</span>(<span class="params">self, vocab_size, context_size=CONTEXT_SIZE, n_dim=EMBEDDING_DIM</span>):</span></span><br><span class="line">        <span class="built_in">super</span>(n_gram, self).__init__()</span><br><span class="line">        </span><br><span class="line">        self.embed = nn.Embedding(vocab_size, n_dim)</span><br><span class="line">        self.classify = nn.Sequential(</span><br><span class="line">            nn.Linear(context_size * n_dim, <span class="number">128</span>),</span><br><span class="line">            nn.ReLU(<span class="literal">True</span>),</span><br><span class="line">            nn.Linear(<span class="number">128</span>, vocab_size)</span><br><span class="line">        )</span><br><span class="line">        </span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">forward</span>(<span class="params">self, x</span>):</span></span><br><span class="line">        voc_embed = self.embed(x) <span class="comment"># 得到词嵌入</span></span><br><span class="line">        voc_embed = voc_embed.view(<span class="number">1</span>, -<span class="number">1</span>) <span class="comment"># 将两个词向量拼在一起</span></span><br><span 
class="line">        out = self.classify(voc_embed)</span><br><span class="line">        <span class="keyword">return</span> out</span><br><span class="line"></span><br><span class="line"><span class="comment"># 我们使用莎士比亚的诗</span></span><br><span class="line">test_sentence = <span class="string">&quot;&quot;&quot;When forty winters shall besiege thy brow,</span></span><br><span class="line"><span class="string">And dig deep trenches in thy beauty&#x27;s field,</span></span><br><span class="line"><span class="string">Thy youth&#x27;s proud livery so gazed on now,</span></span><br><span class="line"><span class="string">Will be a totter&#x27;d weed of small worth held:</span></span><br><span class="line"><span class="string">Then being asked, where all thy beauty lies,</span></span><br><span class="line"><span class="string">Where all the treasure of thy lusty days;</span></span><br><span class="line"><span class="string">To say, within thine own deep sunken eyes,</span></span><br><span class="line"><span class="string">Were an all-eating shame, and thriftless praise.</span></span><br><span class="line"><span class="string">How much more praise deserv&#x27;d thy beauty&#x27;s use,</span></span><br><span class="line"><span class="string">If thou couldst answer &#x27;This fair child of mine</span></span><br><span class="line"><span class="string">Shall sum my count, and make my old excuse,&#x27;</span></span><br><span class="line"><span class="string">Proving his beauty by succession thine!</span></span><br><span class="line"><span class="string">This were to be new made when thou art old,</span></span><br><span class="line"><span class="string">And see thy blood warm when thou feel&#x27;st it cold.&quot;&quot;&quot;</span>.split()</span><br><span class="line"></span><br><span class="line">trigram = [((test_sentence[i], test_sentence[i+<span class="number">1</span>]), test_sentence[i+<span class="number">2</span>]) </span><br><span class="line">            <span 
class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(<span class="built_in">len</span>(test_sentence)-<span class="number">2</span>)]</span><br><span class="line"></span><br><span class="line"><span class="comment"># 建立每个词与数字的编码，据此构建词嵌入</span></span><br><span class="line">vocb = <span class="built_in">set</span>(test_sentence) <span class="comment"># 使用 set 将重复的元素去掉</span></span><br><span class="line">word_to_idx = &#123;word: i <span class="keyword">for</span> i, word <span class="keyword">in</span> <span class="built_in">enumerate</span>(vocb)&#125;</span><br><span class="line">idx_to_word = &#123;word_to_idx[word]: word <span class="keyword">for</span> word <span class="keyword">in</span> word_to_idx&#125;</span><br><span class="line"></span><br><span class="line">net = n_gram(<span class="built_in">len</span>(word_to_idx))</span><br><span class="line">criterion = nn.CrossEntropyLoss()</span><br><span class="line">optimizer = torch.optim.SGD(net.parameters(), lr=<span class="number">1e-2</span>, weight_decay=<span class="number">1e-5</span>)</span><br><span class="line"></span><br><span class="line"><span class="comment"># 开始训练</span></span><br><span class="line"><span class="keyword">for</span> e <span class="keyword">in</span> <span class="built_in">range</span>(<span class="number">100</span>):</span><br><span class="line">    train_loss = <span class="number">0</span></span><br><span class="line">    <span class="keyword">for</span> word, label <span class="keyword">in</span> trigram: <span class="comment"># 使用前 100 个作为训练集</span></span><br><span class="line">        word = Variable(torch.LongTensor([word_to_idx[i] <span class="keyword">for</span> i <span class="keyword">in</span> word])) <span class="comment"># 将两个词作为输入</span></span><br><span class="line">        label = Variable(torch.LongTensor([word_to_idx[label]]))</span><br><span class="line">        <span class="comment"># 前向传播</span></span><br><span 
class="line">        out = net(word)</span><br><span class="line">        loss = criterion(out, label)</span><br><span class="line">        train_loss += loss.data</span><br><span class="line">        <span class="comment"># 反向传播</span></span><br><span class="line">        optimizer.zero_grad()</span><br><span class="line">        loss.backward()</span><br><span class="line">        optimizer.step()</span><br><span class="line">    <span class="keyword">if</span> (e + <span class="number">1</span>) % <span class="number">20</span> == <span class="number">0</span>:</span><br><span class="line">        <span class="built_in">print</span>(<span class="string">&#x27;epoch: &#123;&#125;, Loss: &#123;:.6f&#125;&#x27;</span>.<span class="built_in">format</span>(e + <span class="number">1</span>, train_loss / <span class="built_in">len</span>(trigram)))</span><br><span class="line">        </span><br><span class="line"><span class="comment"># 测试</span></span><br><span class="line">word, label = trigram[<span class="number">19</span>]</span><br><span class="line"><span class="built_in">print</span>(<span class="string">&#x27;input: &#123;&#125;&#x27;</span>.<span class="built_in">format</span>(word))</span><br><span class="line"><span class="built_in">print</span>(<span class="string">&#x27;label: &#123;&#125;&#x27;</span>.<span class="built_in">format</span>(label))</span><br><span class="line"></span><br><span class="line">word = Variable(torch.LongTensor([word_to_idx[i] <span class="keyword">for</span> i <span class="keyword">in</span> word]))</span><br><span class="line">out = net(word)</span><br><span class="line">pred_label_idx = out.<span class="built_in">max</span>(<span class="number">1</span>)[<span class="number">1</span>].item()</span><br><span class="line">predict_word = idx_to_word[pred_label_idx]</span><br><span class="line"><span class="built_in">print</span>(<span class="string">&#x27;real word is &#123;&#125;, predicted word is 
&#123;&#125;&#x27;</span>.<span class="built_in">format</span>(label, predict_word))</span><br></pre></td></tr></table></figure><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line">epoch: 20, Loss: 0.873597</span><br><span class="line">epoch: 40, Loss: 0.153170</span><br><span class="line">epoch: 60, Loss: 0.090456</span><br><span class="line">epoch: 80, Loss: 0.071410</span><br><span class="line">epoch: 100, Loss: 0.061979</span><br><span class="line">input: (&#x27;so&#x27;, &#x27;gazed&#x27;)</span><br><span class="line">label: on</span><br><span class="line"></span><br><span class="line">real word is on, predicted word is on</span><br></pre></td></tr></table></figure><h4 id="5-4-5-词性判断"><a href="#5-4-5-词性判断" class="headerlink" title="5.4.5 词性判断"></a>5.4.5 词性判断</h4><p><strong>1.LSTM做词性判断的基本原理</strong></p><p>通过LSTM，根据它记忆的特性，能够通过这个单词前面记忆的一些词语来对它做一个判断，比如前面的单词如果是my，那么紧跟的词很可能是一个名词，这样就能充分利用上文来处理这个问题</p><p><strong>2.字符增强</strong></p><p>通过引入字符来增强表达，比如有些单词存在前缀或者后缀，比如<code>-ly</code>这种后缀很有可能是副词，这样我们就能在字符水平对词性进一步判断，把两种方法集成起来，能够得到一个更好的结果</p><h4 id="5-4-6-词性判断的PyTorch实现"><a href="#5-4-6-词性判断的PyTorch实现" class="headerlink" title="5.4.6 词性判断的PyTorch实现"></a>5.4.6 词性判断的PyTorch实现</h4><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span 
class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span 
class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br><span class="line">87</span><br><span class="line">88</span><br><span class="line">89</span><br><span class="line">90</span><br><span class="line">91</span><br><span class="line">92</span><br><span class="line">93</span><br><span class="line">94</span><br><span class="line">95</span><br><span class="line">96</span><br><span class="line">97</span><br><span class="line">98</span><br><span class="line">99</span><br><span class="line">100</span><br><span class="line">101</span><br><span class="line">102</span><br><span class="line">103</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">from</span> torch <span class="keyword">import</span> nn</span><br><span class="line"><span class="keyword">from</span> torch.autograd <span class="keyword">import</span> Variable</span><br><span class="line"></span><br><span class="line">training_data = [(<span class="string">&quot;The dog ate the apple&quot;</span>.split(),</span><br><span class="line">                  [<span class="string">&quot;DET&quot;</span>, <span class="string">&quot;NN&quot;</span>, <span class="string">&quot;V&quot;</span>, <span class="string">&quot;DET&quot;</span>, <span class="string">&quot;NN&quot;</span>]),</span><br><span class="line">                 (<span class="string">&quot;Everybody read that book&quot;</span>.split(), </span><br><span class="line">                  [<span class="string">&quot;NN&quot;</span>, <span class="string">&quot;V&quot;</span>, <span class="string">&quot;DET&quot;</span>, <span 
class="string">&quot;NN&quot;</span>])]</span><br><span class="line"></span><br><span class="line"><span class="comment">#对单词和标签进行编码</span></span><br><span class="line">word_to_idx = &#123;&#125;</span><br><span class="line">tag_to_idx = &#123;&#125;</span><br><span class="line"><span class="keyword">for</span> context, tag <span class="keyword">in</span> training_data:</span><br><span class="line">    <span class="keyword">for</span> word <span class="keyword">in</span> context:</span><br><span class="line">        <span class="keyword">if</span> word.lower() <span class="keyword">not</span> <span class="keyword">in</span> word_to_idx:</span><br><span class="line">            word_to_idx[word.lower()] = <span class="built_in">len</span>(word_to_idx)</span><br><span class="line">    <span class="keyword">for</span> label <span class="keyword">in</span> tag:</span><br><span class="line">        <span class="keyword">if</span> label.lower() <span class="keyword">not</span> <span class="keyword">in</span> tag_to_idx:</span><br><span class="line">            tag_to_idx[label.lower()] = <span class="built_in">len</span>(tag_to_idx)</span><br><span class="line"></span><br><span class="line"><span class="comment">#对字母编码</span></span><br><span class="line">alphabet = <span class="string">&#x27;abcdefghijklmnopqrstuvwxyz&#x27;</span></span><br><span class="line">char_to_idx = &#123;&#125;</span><br><span class="line"><span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(<span class="built_in">len</span>(alphabet)):</span><br><span class="line">    char_to_idx[alphabet[i]] = i</span><br><span class="line">    </span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">make_sequence</span>(<span class="params">x, dic</span>):</span> <span class="comment"># 字符编码</span></span><br><span class="line">    idx = [dic[i.lower()] <span class="keyword">for</span> i <span 
class="keyword">in</span> x]</span><br><span class="line">    idx = torch.LongTensor(idx)</span><br><span class="line">    <span class="keyword">return</span> idx</span><br><span class="line"></span><br><span class="line"><span class="comment">#构建单个字符的lstm模型</span></span><br><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">char_lstm</span>(<span class="params">nn.Module</span>):</span></span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">__init__</span>(<span class="params">self, n_char, char_dim, char_hidden</span>):</span></span><br><span class="line">        <span class="built_in">super</span>(char_lstm, self).__init__()</span><br><span class="line">        </span><br><span class="line">        self.char_embed = nn.Embedding(n_char, char_dim)</span><br><span class="line">        self.lstm = nn.LSTM(char_dim, char_hidden)</span><br><span class="line">        </span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">forward</span>(<span class="params">self, x</span>):</span></span><br><span class="line">        x = self.char_embed(x)</span><br><span class="line">        out, _ = self.lstm(x)</span><br><span class="line">        <span class="keyword">return</span> out[-<span class="number">1</span>] <span class="comment"># (batch, hidden)</span></span><br><span class="line"></span><br><span class="line"><span class="comment">#构建词性分类的lstm模型</span></span><br><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">lstm_tagger</span>(<span class="params">nn.Module</span>):</span></span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">__init__</span>(<span class="params">self, n_word, n_char, char_dim, word_dim, </span></span></span><br><span class="line"><span class="params"><span class="function">                 char_hidden, 
word_hidden, n_tag</span>):</span></span><br><span class="line">        <span class="built_in">super</span>(lstm_tagger, self).__init__()</span><br><span class="line">        self.word_embed = nn.Embedding(n_word, word_dim)</span><br><span class="line">        self.char_lstm = char_lstm(n_char, char_dim, char_hidden)</span><br><span class="line">        self.word_lstm = nn.LSTM(word_dim + char_hidden, word_hidden)</span><br><span class="line">        self.classify = nn.Linear(word_hidden, n_tag)</span><br><span class="line">        </span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">forward</span>(<span class="params">self, x, word</span>):</span></span><br><span class="line">        char = []</span><br><span class="line">        <span class="keyword">for</span> w <span class="keyword">in</span> word: <span class="comment"># 对于每个单词做字符的 lstm</span></span><br><span class="line">            char_list = make_sequence(w, char_to_idx)</span><br><span class="line">            char_list = char_list.unsqueeze(<span class="number">1</span>) <span class="comment"># (seq, batch, feature) 满足 lstm 输入条件</span></span><br><span class="line">            char_infor = self.char_lstm(Variable(char_list)) <span class="comment"># (batch, char_hidden)</span></span><br><span class="line">            char.append(char_infor)</span><br><span class="line">        char = torch.stack(char, dim=<span class="number">0</span>) <span class="comment"># (seq, batch, feature)</span></span><br><span class="line">        </span><br><span class="line">        x = self.word_embed(x) <span class="comment"># (batch, seq, word_dim)</span></span><br><span class="line">        x = x.permute(<span class="number">1</span>, <span class="number">0</span>, <span class="number">2</span>) <span class="comment"># 改变顺序</span></span><br><span class="line">        x = torch.cat((x, char), dim=<span class="number">2</span>) <span class="comment"># 沿着特征通道将每个词的词嵌入和字符 
lstm 输出的结果拼接在一起</span></span><br><span class="line">        x, _ = self.word_lstm(x)</span><br><span class="line">        </span><br><span class="line">        s, b, h = x.shape</span><br><span class="line">        x = x.view(-<span class="number">1</span>, h) <span class="comment"># 重新 reshape 进行分类线性层</span></span><br><span class="line">        out = self.classify(x)</span><br><span class="line">        <span class="keyword">return</span> out</span><br><span class="line"></span><br><span class="line">net = lstm_tagger(<span class="built_in">len</span>(word_to_idx), <span class="built_in">len</span>(char_to_idx), <span class="number">10</span>, <span class="number">100</span>, <span class="number">50</span>, <span class="number">128</span>, <span class="built_in">len</span>(tag_to_idx))</span><br><span class="line">criterion = nn.CrossEntropyLoss()</span><br><span class="line">optimizer = torch.optim.SGD(net.parameters(), lr=<span class="number">1e-2</span>)</span><br><span class="line"></span><br><span class="line"><span class="comment"># 开始训练</span></span><br><span class="line"><span class="keyword">for</span> e <span class="keyword">in</span> <span class="built_in">range</span>(<span class="number">300</span>):</span><br><span class="line">    train_loss = <span class="number">0</span></span><br><span class="line">    <span class="keyword">for</span> word, tag <span class="keyword">in</span> training_data:</span><br><span class="line">        word_list = make_sequence(word, word_to_idx).unsqueeze(<span class="number">0</span>) <span class="comment"># 添加第一维 batch</span></span><br><span class="line">        tag = make_sequence(tag, tag_to_idx)</span><br><span class="line">        word_list = Variable(word_list)</span><br><span class="line">        tag = Variable(tag)</span><br><span class="line">        <span class="comment"># 前向传播</span></span><br><span class="line">        out = net(word_list, word)</span><br><span class="line">        loss = criterion(out, 
tag)</span><br><span class="line">        train_loss += loss.data</span><br><span class="line">        <span class="comment"># 反向传播</span></span><br><span class="line">        optimizer.zero_grad()</span><br><span class="line">        loss.backward()</span><br><span class="line">        optimizer.step()</span><br><span class="line">    <span class="keyword">if</span> (e + <span class="number">1</span>) % <span class="number">50</span> == <span class="number">0</span>:</span><br><span class="line">        <span class="built_in">print</span>(<span class="string">&#x26;#x27;Epoch: &#x26;#123;&#x26;#125;, Loss: &#x26;#123;:.5f&#x26;#125;&#x26;#x27;</span>.<span class="built_in">format</span>(e + <span class="number">1</span>, train_loss / <span class="built_in">len</span>(training_data)))</span><br><span class="line"></span><br><span class="line"><span class="comment">#测试</span></span><br><span class="line">net = net.<span class="built_in">eval</span>()</span><br><span class="line">test_sent = <span class="string">&#x26;#x27;Everybody ate the apple&#x26;#x27;</span></span><br><span class="line">test = make_sequence(test_sent.split(), word_to_idx).unsqueeze(<span class="number">0</span>)</span><br><span class="line">out = net(Variable(test), test_sent.split())</span><br><span class="line"><span class="built_in">print</span>(out)</span><br><span class="line"><span class="built_in">print</span>(tag_to_idx)</span><br></pre></td></tr></table></figure><h3 id="5-5-循环神经网络的更多应用"><a href="#5-5-循环神经网络的更多应用" class="headerlink" title="5.5 循环神经网络的更多应用"></a>5.5 循环神经网络的更多应用</h3><h4 id="5-5-1-Many-to-one"><a href="#5-5-1-Many-to-one" class="headerlink" title="5.5.1 Many to one"></a>5.5.1 Many to one</h4><p>循环神经网络不仅能够输入序列，输出序列，还能够输入序列，输出单个向量。只需要在输出的序列里面取其中一个就可以，通常是取最后一个。这样的结构被称为Many to one。</p><p>Many to one的结构可以用来执行什么任务：</p><ul><li>情感分析</li><li>关键字提取</li></ul><h4 id="5-5-2-Many-to-Many-shorter"><a href="#5-5-2-Many-to-Many-shorter" class="headerlink" title="5.5.2 Many to Many (shorter)"></a>5.5.2 Many to Many 
(shorter)</h4><p>这种结构是输入和输出都是序列，但是输出的序列比输入的序列短。这种类型的结构通常在语音识别中遇到，因为一段话的语音序列往往会比它对应的文字序列更长。这种情况需要使用CTC算法解决重复的问题，CTC就是将输出的所有可能列举出来，然后通过去重复，去空格的方式来选择最大的概率。</p><h4 id="5-5-3-Seq2seq"><a href="#5-5-3-Seq2seq" class="headerlink" title="5.5.3 Seq2seq"></a>5.5.3 Seq2seq</h4><p>这种情况是输出的长度不确定，一般是在机器翻译的任务中出现。</p><h4 id="5-5-4-CNN-RNN"><a href="#5-5-4-CNN-RNN" class="headerlink" title="5.5.4 CNN+RNN"></a>5.5.4 CNN+RNN</h4><p>RNN和CNN可以联合在一起完成图像描述任务，简而言之，就是通过预训练的卷积神经网络提取图片特征，接着通过循环网络将特征变成文字描述</p><h2 id="第6章-生成对抗网络"><a href="#第6章-生成对抗网络" class="headerlink" title="第6章 生成对抗网络"></a>第6章 生成对抗网络</h2><p>2014年，Ian Goodfellow提出的生成对抗网络（Generative Adversarial Networks，GANs）推进了整个无监督学习的发展进程，让机器实现一些创造性工作，如画画，写诗，创作歌词等成为可能…</p><h3 id="6-1-生成模型"><a href="#6-1-生成模型" class="headerlink" title="6.1 生成模型"></a>6.1 生成模型</h3><p>生成模型(Generative Model)这一概念属于概率统计和机器学习,是指一系列用于随机生成可观测数据的模型.简而言之,就是“生成”的样本和“真实”的样本尽可能地相似.</p><p>生成模型的两个主要功能就是学习一个概率分布$P_{model}(x)$和生成数据</p><h4 id="6-1-1-自动编码器"><a href="#6-1-1-自动编码器" class="headerlink" title="6.1.1 自动编码器"></a>6.1.1 自动编码器</h4><p>自动编码器(AutoEncoder)最开始作为一种数据的压缩方法,其特点有:</p><ul><li>和数据相关程度很高</li><li>压缩后数据是有损的</li></ul><p>所以现在自动编码器主要应用在几个方面:</p><ul><li>数据去噪</li><li>可视化降维</li><li>生成数据</li></ul><p>自动编码器的一般结构</p><ul><li>编码器(Encoder)</li><li>解码器(Decoder)</li></ul><p>编码器和解码器可以是任意的模型,通常使用神经网络模型作为编码器和解码器.输入的数据经过神经网络降维到一个编码(code),接着又通过另一个神经网络去解码得到一个与输入原数据一模一样的生成数据,然后通过比较这两个数据,最小化它们之间的差异来训练这个网络中编码器和解码器的参数.当这个过程训练完之后,拿出这个解码器,随机传入一个编码,通过解码器能够生成一个和原数据差不多的数据</p><p>下面我们使用 mnist 数据集来说明如何构建一个简单的自动编码器</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span 
class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span 
class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> os</span><br><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">from</span> torch.autograd <span class="keyword">import</span> Variable</span><br><span class="line"><span class="keyword">from</span> torch <span class="keyword">import</span> nn</span><br><span class="line"><span class="keyword">from</span> torch.utils.data <span class="keyword">import</span> DataLoader</span><br><span class="line"><span class="keyword">from</span> torchvision.datasets <span class="keyword">import</span> MNIST</span><br><span class="line"><span class="keyword">from</span> torchvision <span class="keyword">import</span> transforms <span class="keyword">as</span> tfs</span><br><span class="line"><span class="keyword">from</span> torchvision.utils <span class="keyword">import</span> save_image</span><br><span class="line"></span><br><span class="line"><span class="comment">#进行数据预处理和迭代器的构建</span></span><br><span class="line">im_tfs = tfs.Compose([</span><br><span class="line">    tfs.ToTensor(),</span><br><span class="line">    tfs.Normalize([<span class="number">0.5</span>], [<span class="number">0.5</span>]) <span class="comment"># 标准化</span></span><br><span class="line">])</span><br><span class="line"></span><br><span class="line">train_set = MNIST(<span class="string">&#x27;./data&#x27;</span>, train=<span class="literal">True</span>,transform=im_tfs,download=<span class="literal">True</span>)</span><br><span class="line">train_data = DataLoader(train_set, batch_size=<span class="number">128</span>, shuffle=<span class="literal">True</span>)</span><br><span class="line"></span><br><span 
class="line"><span class="comment">#定义网络</span></span><br><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">autoencoder</span>(<span class="params">nn.Module</span>):</span></span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">__init__</span>(<span class="params">self</span>):</span></span><br><span class="line">        <span class="built_in">super</span>(autoencoder,self).__init__()</span><br><span class="line">        self.encoder = nn.Sequential(</span><br><span class="line">            nn.Linear(<span class="number">28</span>*<span class="number">28</span>,<span class="number">128</span>),</span><br><span class="line">            nn.ReLU(<span class="literal">True</span>),</span><br><span class="line">            nn.Linear(<span class="number">128</span>,<span class="number">64</span>),</span><br><span class="line">            nn.ReLU(<span class="literal">True</span>),</span><br><span class="line">            nn.Linear(<span class="number">64</span>,<span class="number">12</span>),</span><br><span class="line">            nn.ReLU(<span class="literal">True</span>),</span><br><span class="line">            nn.Linear(<span class="number">12</span>,<span class="number">3</span>) <span class="comment"># 输出的 code 是 3 维，便于可视化</span></span><br><span class="line">        )</span><br><span class="line">        self.decoder = nn.Sequential(</span><br><span class="line">            nn.Linear(<span class="number">3</span>,<span class="number">12</span>),</span><br><span class="line">            nn.ReLU(<span class="literal">True</span>),</span><br><span class="line">            nn.Linear(<span class="number">12</span>,<span class="number">64</span>),</span><br><span class="line">            nn.ReLU(<span class="literal">True</span>),</span><br><span class="line">            nn.Linear(<span class="number">64</span>,<span class="number">128</span>),</span><br><span 
class="line">            nn.ReLU(<span class="literal">True</span>),</span><br><span class="line">            nn.Linear(<span class="number">128</span>,<span class="number">28</span>*<span class="number">28</span>),</span><br><span class="line">            nn.Tanh()</span><br><span class="line">        )</span><br><span class="line">    </span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">forward</span>(<span class="params">self,x</span>):</span></span><br><span class="line">        encode = self.encoder(x)</span><br><span class="line">        decode = self.decoder(encode)</span><br><span class="line">        <span class="keyword">return</span> encode,decode</span><br><span class="line"><span class="string">&quot;&quot;&quot;</span></span><br><span class="line"><span class="string">这里定义的编码器和解码器都是 4 层神经网络作为模型，</span></span><br><span class="line"><span class="string">中间使用 relu 激活函数，最后输出的 code 是三维，</span></span><br><span class="line"><span class="string">注意解码器最后我们使用tanh作为激活函数，</span></span><br><span class="line"><span class="string">因为输入图片标准化在 -1 ~ 1 之间，</span></span><br><span class="line"><span class="string">所以输出也要在 -1 ~ 1 这个范围内</span></span><br><span class="line"><span class="string">&quot;&quot;&quot;</span></span><br><span class="line">net = autoencoder()</span><br><span class="line">criterion = nn.MSELoss(size_average=<span class="literal">False</span>)</span><br><span class="line">optimizer = torch.optim.Adam(net.parameters(), lr=<span class="number">1e-3</span>)</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">to_img</span>(<span class="params">x</span>):</span></span><br><span class="line">    <span class="comment"># 定义一个函数将最后的结果转换回图片</span></span><br><span class="line">    x = <span class="number">0.5</span> * (x + <span class="number">1.</span>)</span><br><span class="line">    x = x.clamp(<span 
class="number">0</span>, <span class="number">1</span>)</span><br><span class="line">    x = x.view(x.shape[<span class="number">0</span>], <span class="number">1</span>, <span class="number">28</span>, <span class="number">28</span>)</span><br><span class="line">    <span class="keyword">return</span> x</span><br><span class="line"></span><br><span class="line"><span class="comment"># 开始训练自动编码器</span></span><br><span class="line"><span class="keyword">for</span> e <span class="keyword">in</span> <span class="built_in">range</span>(<span class="number">100</span>):</span><br><span class="line">    <span class="keyword">for</span> im, _ <span class="keyword">in</span> train_data:</span><br><span class="line">        im = im.view(im.shape[<span class="number">0</span>], -<span class="number">1</span>)</span><br><span class="line">        im = Variable(im)</span><br><span class="line">        <span class="comment"># 前向传播</span></span><br><span class="line">        _, output = net(im)</span><br><span class="line">        loss = criterion(output, im) / im.shape[<span class="number">0</span>] <span class="comment"># 平均</span></span><br><span class="line">        <span class="comment"># 反向传播</span></span><br><span class="line">        optimizer.zero_grad()</span><br><span class="line">        loss.backward()</span><br><span class="line">        optimizer.step()</span><br><span class="line">    </span><br><span class="line">    <span class="keyword">if</span> (e+<span class="number">1</span>) % <span class="number">20</span> == <span class="number">0</span>: <span class="comment"># 每 20 次，将生成的图片保存一下</span></span><br><span class="line">        <span class="built_in">print</span>(<span class="string">&#x27;epoch: &#123;&#125;, Loss: &#123;:.4f&#125;&#x27;</span>.<span class="built_in">format</span>(e + <span class="number">1</span>, loss.data))</span><br><span class="line">        pic = to_img(output.cpu().data)</span><br><span class="line">        <span 
class="keyword">if</span> <span class="keyword">not</span> os.path.exists(<span class="string">&#x27;./simple_autoencoder&#x27;</span>):</span><br><span class="line">            os.mkdir(<span class="string">&#x27;./simple_autoencoder&#x27;</span>)</span><br><span class="line">        save_image(pic, <span class="string">&#x27;./simple_autoencoder/image_&#123;&#125;.png&#x27;</span>.<span class="built_in">format</span>(e + <span class="number">1</span>))</span><br></pre></td></tr></table></figure><p>训练完成之后看看效果</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> matplotlib.pyplot <span class="keyword">as</span> plt</span><br><span class="line"><span class="keyword">from</span> matplotlib <span class="keyword">import</span> cm</span><br><span class="line"><span class="keyword">from</span> mpl_toolkits.mplot3d <span class="keyword">import</span> Axes3D</span><br><span class="line">%matplotlib inline</span><br><span class="line"></span><br><span class="line"><span class="comment"># 可视化结果</span></span><br><span class="line">view_data = Variable((train_set.train_data[:<span class="number">200</span>].<span class="built_in">type</span>(torch.FloatTensor).view(-<span 
class="number">1</span>, <span class="number">28</span>*<span class="number">28</span>) / <span class="number">255.</span> - <span class="number">0.5</span>) / <span class="number">0.5</span>)</span><br><span class="line">encode, _ = net(view_data)    <span class="comment"># 提取压缩的特征值</span></span><br><span class="line">fig = plt.figure(<span class="number">2</span>)</span><br><span class="line">ax = Axes3D(fig)    <span class="comment"># 3D 图</span></span><br><span class="line"><span class="comment"># x, y, z 的数据值</span></span><br><span class="line">X = encode.data[:, <span class="number">0</span>].numpy()</span><br><span class="line">Y = encode.data[:, <span class="number">1</span>].numpy()</span><br><span class="line">Z = encode.data[:, <span class="number">2</span>].numpy()</span><br><span class="line">values = train_set.train_labels[:<span class="number">200</span>].numpy()  <span class="comment"># 标签值</span></span><br><span class="line"><span class="keyword">for</span> x, y, z, s <span class="keyword">in</span> <span class="built_in">zip</span>(X, Y, Z, values):</span><br><span class="line">    c = cm.rainbow(<span class="built_in">int</span>(<span class="number">255</span>*s/<span class="number">9</span>))    <span class="comment"># 上色</span></span><br><span class="line">    ax.text(x, y, z, s, backgroundcolor=c)  <span class="comment"># 标位子</span></span><br><span class="line">ax.set_xlim(X.<span class="built_in">min</span>(), X.<span class="built_in">max</span>())</span><br><span class="line">ax.set_ylim(Y.<span class="built_in">min</span>(), Y.<span class="built_in">max</span>())</span><br><span class="line">ax.set_zlim(Z.<span class="built_in">min</span>(), Z.<span class="built_in">max</span>())</span><br><span class="line">plt.show()</span><br></pre></td></tr></table></figure><p><img src="" data-lazy-src="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN2/post/untitled.png"><br>可以看到，不同种类的图片进入自动编码器之后会被编码得不同，而相同类型的图片经过自动编码之后的编码在几何示意图上距离较近，在训练好自动编码器之后，我们可以给一个随机的 
code，通过 decoder 生成图片</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line">code = Variable(torch.FloatTensor([[-<span class="number">20.19</span>, <span class="number">10.36</span>, -<span class="number">0.06</span>]])) <span class="comment"># 给一个 code</span></span><br><span class="line">decode = net.decoder(code)</span><br><span class="line">decode_img = to_img(decode).squeeze()</span><br><span class="line">decode_img = decode_img.data.numpy() * <span class="number">255</span></span><br><span class="line">plt.imshow(decode_img.astype(<span class="string">&#x27;uint8&#x27;</span>), cmap=<span class="string">&#x27;gray&#x27;</span>) </span><br></pre></td></tr></table></figure><p><img src="" data-lazy-src="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN2/post/20201104180838.png"><br>这里我们仅仅使用多层神经网络定义了一个自动编码器，当然你会想到，为什么不使用效果更好的卷积神经网络呢？我们当然可以使用卷积神经网络来定义，下面我们就重新定义一个卷积神经网络来进行 autoencoder</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span 
class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">conv_autoencoder</span>(<span class="params">nn.Module</span>):</span></span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">__init__</span>(<span class="params">self</span>):</span></span><br><span class="line">        <span class="built_in">super</span>(conv_autoencoder, self).__init__()</span><br><span class="line">        </span><br><span class="line">        self.encoder = nn.Sequential(</span><br><span class="line">            nn.Conv2d(<span class="number">1</span>, <span class="number">16</span>, <span class="number">3</span>, stride=<span class="number">3</span>, padding=<span class="number">1</span>),  <span class="comment"># (b, 16, 10, 10)</span></span><br><span class="line">            nn.ReLU(<span class="literal">True</span>),</span><br><span class="line">            nn.MaxPool2d(<span class="number">2</span>, stride=<span class="number">2</span>),  <span class="comment"># (b, 16, 5, 
5)</span></span><br><span class="line">            nn.Conv2d(<span class="number">16</span>, <span class="number">8</span>, <span class="number">3</span>, stride=<span class="number">2</span>, padding=<span class="number">1</span>),  <span class="comment"># (b, 8, 3, 3)</span></span><br><span class="line">            nn.ReLU(<span class="literal">True</span>),</span><br><span class="line">            nn.MaxPool2d(<span class="number">2</span>, stride=<span class="number">1</span>)  <span class="comment"># (b, 8, 2, 2)</span></span><br><span class="line">        )</span><br><span class="line">        </span><br><span class="line">        self.decoder = nn.Sequential(</span><br><span class="line">            nn.ConvTranspose2d(<span class="number">8</span>, <span class="number">16</span>, <span class="number">3</span>, stride=<span class="number">2</span>),  <span class="comment"># (b, 16, 5, 5)</span></span><br><span class="line">            nn.ReLU(<span class="literal">True</span>),</span><br><span class="line">            nn.ConvTranspose2d(<span class="number">16</span>, <span class="number">8</span>, <span class="number">5</span>, stride=<span class="number">3</span>, padding=<span class="number">1</span>),  <span class="comment"># (b, 8, 15, 15)</span></span><br><span class="line">            nn.ReLU(<span class="literal">True</span>),</span><br><span class="line">            nn.ConvTranspose2d(<span class="number">8</span>, <span class="number">1</span>, <span class="number">2</span>, stride=<span class="number">2</span>, padding=<span class="number">1</span>),  <span class="comment"># (b, 1, 28, 28)</span></span><br><span class="line">            nn.Tanh()</span><br><span class="line">        )</span><br><span class="line"></span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">forward</span>(<span class="params">self, x</span>):</span></span><br><span class="line">        encode = 
self.encoder(x)</span><br><span class="line">        decode = self.decoder(encode)</span><br><span class="line">        <span class="keyword">return</span> encode, decode</span><br><span class="line"></span><br><span class="line">conv_net = conv_autoencoder()</span><br><span class="line"><span class="keyword">if</span> torch.cuda.is_available():</span><br><span class="line">    conv_net = conv_net.cuda()</span><br><span class="line">optimizer = torch.optim.Adam(conv_net.parameters(), lr=<span class="number">1e-3</span>, weight_decay=<span class="number">1e-5</span>)</span><br><span class="line"></span><br><span class="line"><span class="comment"># 开始训练自动编码器</span></span><br><span class="line"><span class="keyword">for</span> e <span class="keyword">in</span> <span class="built_in">range</span>(<span class="number">40</span>):</span><br><span class="line">    <span class="keyword">for</span> im, _ <span class="keyword">in</span> train_data:</span><br><span class="line">        <span class="keyword">if</span> torch.cuda.is_available():</span><br><span class="line">            im = im.cuda()</span><br><span class="line">            <span class="built_in">print</span>(torch.device(<span class="string">&quot;cuda&quot;</span>))</span><br><span class="line">        im = Variable(im)</span><br><span class="line">        <span class="comment"># 前向传播</span></span><br><span class="line">        _, output = conv_net(im)</span><br><span class="line">        loss = criterion(output, im) / im.shape[<span class="number">0</span>] <span class="comment"># 平均</span></span><br><span class="line">        <span class="comment"># 反向传播</span></span><br><span class="line">        optimizer.zero_grad()</span><br><span class="line">        loss.backward()</span><br><span class="line">        optimizer.step()</span><br><span class="line">    </span><br><span class="line">    <span class="keyword">if</span> (e+<span class="number">1</span>) % <span class="number">20</span> == <span 
class="number">0</span>: <span class="comment"># 每 20 次，将生成的图片保存一下</span></span><br><span class="line">        <span class="built_in">print</span>(<span class="string">&#x27;epoch: &#123;&#125;, Loss: &#123;:.4f&#125;&#x27;</span>.<span class="built_in">format</span>(e+<span class="number">1</span>, loss.data))</span><br><span class="line">        pic = to_img(output.cpu().data)</span><br><span class="line">        <span class="keyword">if</span> <span class="keyword">not</span> os.path.exists(<span class="string">&#x27;./conv_autoencoder&#x27;</span>):</span><br><span class="line">            os.mkdir(<span class="string">&#x27;./conv_autoencoder&#x27;</span>)</span><br><span class="line">        save_image(pic, <span class="string">&#x27;./conv_autoencoder/image_&#123;&#125;.png&#x27;</span>.<span class="built_in">format</span>(e+<span class="number">1</span>))</span><br></pre></td></tr></table></figure><p>为了时间更短，只跑 40 次，如果有条件可以在 gpu 上跑跑。这里我们展示了简单的自动编码器，也用了多层神经网络和卷积神经网络作为例子，但是自动编码器存在一个问题，我们并不能任意生成我们想要的数据，因为我们并不知道 encode 之后的编码到底是什么样的概率分布，所以有一个改进的版本变分自动编码器，其能够解决这个问题</p><h4 id="6-1-2-变分自动编码器"><a href="#6-1-2-变分自动编码器" class="headerlink" title="6.1.2 变分自动编码器"></a>6.1.2 变分自动编码器</h4><p>变分自动编码器（Variational Auto Encoder, VAE）是自动编码器的升级版本，它的结构和自动编码器相似，也是由编码器和解码器构成的。</p><p>自动编码器不能任意生成数据，因为没办法自己去构造隐藏向量，需要通过数据输入编码才知道得到的隐含向量是什么，这个时候变分自动编码器就可以解决这个问题</p><p>它的原理是，在编码过程给他增加一些限制，迫使他生成的隐含向量能够粗略地遵循一个标准正态分布。</p><p>这样我们生成一张新图片就很简单了，我们只需要给它一个标准正态分布的随机隐含向量，这样通过解码器就能够生成我们想要的图片，而不需要给它一张原始图片先编码。</p><p>一般来讲，我们通过 encoder 得到的隐含向量并不是一个标准的正态分布，为了衡量两种分布的相似程度，我们使用 KL divergence，利用其来表示隐含向量与标准正态分布之间差异的 loss，另外一个 loss 仍然使用生成图片与原图片的均方误差来表示。</p><p>KL divergence 的公式如下<br>$$<br>D_{KL} (P || Q) = \sum_{i} P(i) \log \frac{P(i)}{Q(i)}<br>$$</p><p>$$<br>D_{KL} (P || Q) = \int_{-\infty}^{\infty} p(x) \log \frac{p(x)}{q(x)} dx<br>$$</p><p><strong>重参数</strong></p><p>为了避免计算 KL divergence 中的积分，我们使用重参数的技巧，不是每次产生一个隐含向量，而是生成两个向量，一个表示均值，一个表示标准差，这里我们默认编码之后的隐含向量服从一个正态分布之后，就可以用一个标准正态分布先乘上标准差再加上均值来合成这个正态分布，最后 loss 
就是希望这个生成的正态分布能够符合一个标准正态分布，也就是希望均值为 0，方差为 1</p><p><a target="_blank" rel="external nofollow noopener noreferrer" href="https://arxiv.org/pdf/1606.05908.pdf">详细内容见https://arxiv.org/pdf/1606.05908.pdf</a></p><p>所以最后我们可以将我们的 loss 定义为下面的函数，由均方误差和 KL divergence 求和得到一个总的 loss</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br></pre></td><td class="code"><pre><span class="line">reconstruction_function = nn.BCELoss(size_average=<span class="literal">False</span>)</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">loss_function</span>(<span class="params">recon_x, x, mu, logvar</span>):</span></span><br><span class="line">    <span class="string">&quot;&quot;&quot;</span></span><br><span class="line"><span class="string">    recon_x: generating images</span></span><br><span class="line"><span class="string">    x: origin images</span></span><br><span class="line"><span class="string">    mu: latent mean</span></span><br><span class="line"><span class="string">    logvar: latent log variance</span></span><br><span class="line"><span class="string">    &quot;&quot;&quot;</span></span><br><span class="line">    MSE = reconstruction_function(recon_x, x)</span><br><span class="line">    <span class="comment"># loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)</span></span><br><span class="line">    KLD_element = mu.<span class="built_in">pow</span>(<span class="number">2</span>).add_(logvar.exp()).mul_(-<span 
class="number">1</span>).add_(<span class="number">1</span>).add_(logvar)</span><br><span class="line">    KLD = torch.<span class="built_in">sum</span>(KLD_element).mul_(-<span class="number">0.5</span>)</span><br><span class="line">    <span class="comment"># KL divergence</span></span><br><span class="line">    <span class="keyword">return</span> MSE + KLD</span><br></pre></td></tr></table></figure><p>下面我们用 mnist 数据集来简单说明一下变分自动编码器</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span 
class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br><span class="line">87</span><br><span class="line">88</span><br><span class="line">89</span><br><span class="line">90</span><br><span class="line">91</span><br><span class="line">92</span><br><span class="line">93</span><br><span class="line">94</span><br><span class="line">95</span><br><span class="line">96</span><br><span class="line">97</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> os</span><br><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">from</span> torch.autograd <span class="keyword">import</span> Variable</span><br><span class="line"><span class="keyword">import</span> torch.nn.functional 
<span class="keyword">as</span> F</span><br><span class="line"><span class="keyword">from</span> torch <span class="keyword">import</span> nn</span><br><span class="line"><span class="keyword">from</span> torch.utils.data <span class="keyword">import</span> DataLoader</span><br><span class="line"><span class="keyword">from</span> torchvision.datasets <span class="keyword">import</span> MNIST</span><br><span class="line"><span class="keyword">from</span> torchvision <span class="keyword">import</span> transforms <span class="keyword">as</span> tfs</span><br><span class="line"><span class="keyword">from</span> torchvision.utils <span class="keyword">import</span> save_image</span><br><span class="line"></span><br><span class="line">im_tfs = tfs.Compose([</span><br><span class="line">    tfs.ToTensor(),</span><br><span class="line">    tfs.Normalize([<span class="number">0.5</span>], [<span class="number">0.5</span>]) <span class="comment"># 标准化</span></span><br><span class="line">])</span><br><span class="line"></span><br><span class="line">train_set = MNIST(<span class="string">&#x27;./data&#x27;</span>, transform=im_tfs)</span><br><span class="line">train_data = DataLoader(train_set, batch_size=<span class="number">128</span>, shuffle=<span class="literal">True</span>)</span><br><span class="line"></span><br><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">VAE</span>(<span class="params">nn.Module</span>):</span></span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">__init__</span>(<span class="params">self</span>):</span></span><br><span class="line">        <span class="built_in">super</span>(VAE, self).__init__()</span><br><span class="line">        self.fc1 = nn.Linear(<span class="number">784</span>, <span class="number">400</span>)</span><br><span class="line">        self.fc21 = nn.Linear(<span class="number">400</span>, <span class="number">20</span>) <span 
class="comment"># mean</span></span><br><span class="line">        self.fc22 = nn.Linear(<span class="number">400</span>, <span class="number">20</span>) <span class="comment"># var</span></span><br><span class="line">        self.fc3 = nn.Linear(<span class="number">20</span>, <span class="number">400</span>)</span><br><span class="line">        self.fc4 = nn.Linear(<span class="number">400</span>, <span class="number">784</span>)</span><br><span class="line"></span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">encode</span>(<span class="params">self, x</span>):</span></span><br><span class="line">        h1 = F.relu(self.fc1(x))</span><br><span class="line">        <span class="keyword">return</span> self.fc21(h1), self.fc22(h1)</span><br><span class="line"></span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">reparametrize</span>(<span class="params">self, mu, logvar</span>):</span></span><br><span class="line">        std = logvar.mul(<span class="number">0.5</span>).exp_()</span><br><span class="line">        eps = torch.FloatTensor(std.size()).normal_()</span><br><span class="line">        <span class="keyword">if</span> torch.cuda.is_available():</span><br><span class="line">            eps = Variable(eps.cuda())</span><br><span class="line">        <span class="keyword">else</span>:</span><br><span class="line">            eps = Variable(eps)</span><br><span class="line">        <span class="keyword">return</span> eps.mul(std).add_(mu)</span><br><span class="line"></span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">decode</span>(<span class="params">self, z</span>):</span></span><br><span class="line">        h3 = F.relu(self.fc3(z))</span><br><span class="line">        <span class="keyword">return</span> F.tanh(self.fc4(h3))</span><br><span class="line"></span><br><span class="line">    
<span class="function"><span class="keyword">def</span> <span class="title">forward</span>(<span class="params">self, x</span>):</span></span><br><span class="line">        mu, logvar = self.encode(x) <span class="comment"># 编码</span></span><br><span class="line">        z = self.reparametrize(mu, logvar) <span class="comment"># 重新参数化成正态分布</span></span><br><span class="line">        <span class="keyword">return</span> self.decode(z), mu, logvar <span class="comment"># 解码，同时输出均值方差</span></span><br><span class="line"></span><br><span class="line">net = VAE() <span class="comment"># 实例化网络</span></span><br><span class="line"><span class="keyword">if</span> torch.cuda.is_available():</span><br><span class="line">    net = net.cuda()</span><br><span class="line"></span><br><span class="line">reconstruction_function = nn.MSELoss(size_average=<span class="literal">False</span>)</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">loss_function</span>(<span class="params">recon_x, x, mu, logvar</span>):</span></span><br><span class="line">    <span class="string">&quot;&quot;&quot;</span></span><br><span class="line"><span class="string">    recon_x: generating images</span></span><br><span class="line"><span class="string">    x: origin images</span></span><br><span class="line"><span class="string">    mu: latent mean</span></span><br><span class="line"><span class="string">    logvar: latent log variance</span></span><br><span class="line"><span class="string">    &quot;&quot;&quot;</span></span><br><span class="line">    MSE = reconstruction_function(recon_x, x)</span><br><span class="line">    <span class="comment"># loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)</span></span><br><span class="line">    KLD_element = mu.<span class="built_in">pow</span>(<span class="number">2</span>).add_(logvar.exp()).mul_(-<span class="number">1</span>).add_(<span 
class="number">1</span>).add_(logvar)</span><br><span class="line">    KLD = torch.<span class="built_in">sum</span>(KLD_element).mul_(-<span class="number">0.5</span>)</span><br><span class="line">    <span class="comment"># KL divergence</span></span><br><span class="line">    <span class="keyword">return</span> MSE + KLD</span><br><span class="line"></span><br><span class="line">optimizer = torch.optim.Adam(net.parameters(), lr=<span class="number">1e-3</span>)</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">to_img</span>(<span class="params">x</span>):</span></span><br><span class="line">    <span class="comment">#定义一个函数将最后的结果转换回图片</span></span><br><span class="line">    x = <span class="number">0.5</span> * (x + <span class="number">1.</span>)</span><br><span class="line">    x = x.clamp(<span class="number">0</span>, <span class="number">1</span>)</span><br><span class="line">    x = x.view(x.shape[<span class="number">0</span>], <span class="number">1</span>, <span class="number">28</span>, <span class="number">28</span>)</span><br><span class="line">    <span class="keyword">return</span> x</span><br><span class="line"></span><br><span class="line"><span class="keyword">for</span> e <span class="keyword">in</span> <span class="built_in">range</span>(<span class="number">100</span>):</span><br><span class="line">    <span class="keyword">for</span> im, _ <span class="keyword">in</span> train_data:</span><br><span class="line">        im = im.view(im.shape[<span class="number">0</span>], -<span class="number">1</span>)</span><br><span class="line">        im = Variable(im)</span><br><span class="line">        <span class="keyword">if</span> torch.cuda.is_available():</span><br><span class="line">            im = im.cuda()</span><br><span class="line">            <span class="built_in">print</span>(torch.device(<span class="string">&quot;cuda&quot;</span>))</span><br><span 
class="line">        recon_im, mu, logvar = net(im)</span><br><span class="line">        loss = loss_function(recon_im, im, mu, logvar) / im.shape[<span class="number">0</span>] <span class="comment"># 将 loss 平均</span></span><br><span class="line">        optimizer.zero_grad()</span><br><span class="line">        loss.backward()</span><br><span class="line">        optimizer.step()</span><br><span class="line"></span><br><span class="line">    <span class="keyword">if</span> (e + <span class="number">1</span>) % <span class="number">20</span> == <span class="number">0</span>:</span><br><span class="line">        <span class="built_in">print</span>(<span class="string">&#x27;epoch: &#123;&#125;, Loss: &#123;:.4f&#125;&#x27;</span>.<span class="built_in">format</span>(e + <span class="number">1</span>, loss.data[<span class="number">0</span>]))</span><br><span class="line">        save = to_img(recon_im.cpu().data)</span><br><span class="line">        <span class="keyword">if</span> <span class="keyword">not</span> os.path.exists(<span class="string">&#x27;./vae_img&#x27;</span>):</span><br><span class="line">            os.mkdir(<span class="string">&#x27;./vae_img&#x27;</span>)</span><br><span class="line">        save_image(save, <span class="string">&#x27;./vae_img/image_&#123;&#125;.png&#x27;</span>.<span class="built_in">format</span>(e + <span class="number">1</span>))</span><br></pre></td></tr></table></figure><p>可以看看使用变分自动编码器得到的结果，可以发现效果比一般的编码器要好很多</p><h3 id="6-2-生成对抗网络"><a href="#6-2-生成对抗网络" class="headerlink" title="6.2 生成对抗网络"></a>6.2 生成对抗网络</h3><p>前面我们讲了自动编码器和变分自动编码器，不管是哪一个，都是通过计算生成图像和输入图像在每个像素点的误差来生成 loss，这一点是特别不好的，因为不同的像素点可能造成不同的视觉结果，但是可能他们的 loss 是相同的，所以通过单个像素点来得到 loss 是不准确的，这个时候我们需要一种全新的 loss 定义方式，就是通过对抗进行学习。</p><h4 id="6-2-1-什么是生成对抗网络"><a href="#6-2-1-什么是生成对抗网络" class="headerlink" title="6.2.1 什么是生成对抗网络"></a>6.2.1 什么是生成对抗网络</h4><p>这种训练方式定义了一种全新的网络结构，就是生成对抗网络，也就是 
GANs。</p><p>根据这个名字就可以知道这个网络是由两部分组成的，第一部分是生成，第二部分是对抗。简单来说，就是有一个生成网络和一个判别网络，通过训练让两个网络相互竞争，生成网络来生成假的数据，对抗网络通过判别器去判别真伪，最后希望生成器生成的数据能够以假乱真。</p><p><strong>对抗：Discriminator Network</strong></p><p>首先我们来讲一下对抗过程，因为这个过程更加简单。</p><p>对抗过程简单来说就是一个判断真假的判别器，相当于一个二分类问题，我们输入一张真的图片希望判别器输出的结果是1，输入一张假的图片希望判别器输出的结果是0。这其实已经和原图片的 label 没有关系了，不管原图片到底是一个多少类别的图片，他们都统一称为真的图片，label 是 1 表示真实的；而生成的假的图片的 label 是 0 表示假的。</p><p>我们训练的过程就是希望这个判别器能够正确的判出真的图片和假的图片，这其实就是一个简单的二分类问题，对于这个问题可以用我们前面讲过的很多方法去处理，比如 logistic 回归，深层网络，卷积神经网络，循环神经网络都可以。</p><p><strong>生成：Generator Network</strong></p><p>接着我们看看生成网络如何生成一张假的图片。首先给出一个简单的高维的正态分布的噪声向量，这个时候我们可以通过仿射变换，也就是 xw+b 将其映射到一个更高的维度，然后将他重新排列成一个矩形，这样看着更像一张图片，接着进行一些卷积、转置卷积、池化、激活函数等进行处理，最后得到了一个与我们输入图片大小一模一样的噪音矩阵，这就是我们所说的假的图片。</p><p>这个时候我们如何去训练这个生成器呢？这就需要通过对抗学习，增大判别器判别这个结果为真的概率，通过这个步骤不断调整生成器的参数，希望生成的图片越来越像真的，而在这一步中我们不会更新判别器的参数，因为如果判别器不断被优化，可能生成器无论生成什么样的图片都无法骗过判别器。</p><p>关于生成对抗网络，出现了很多变形，比如 WGAN，LS-GAN 等等，这里我们只使用 mnist 举一些简单的例子来说明，更复杂的网络结构可以在 github 上找到相应的实现</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span 
class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">from</span> torch <span class="keyword">import</span> nn</span><br><span class="line"><span class="keyword">from</span> torch.autograd <span class="keyword">import</span> Variable</span><br><span class="line"></span><br><span class="line"><span class="keyword">import</span> torchvision.transforms <span class="keyword">as</span> tfs</span><br><span class="line"><span class="keyword">from</span> torch.utils.data <span 
class="keyword">import</span> DataLoader, sampler</span><br><span class="line"><span class="keyword">from</span> torchvision.datasets <span class="keyword">import</span> MNIST</span><br><span class="line"></span><br><span class="line"><span class="keyword">import</span> numpy <span class="keyword">as</span> np</span><br><span class="line"></span><br><span class="line"><span class="keyword">import</span> matplotlib.pyplot <span class="keyword">as</span> plt</span><br><span class="line"><span class="keyword">import</span> matplotlib.gridspec <span class="keyword">as</span> gridspec</span><br><span class="line"></span><br><span class="line">%matplotlib inline</span><br><span class="line">plt.rcParams[<span class="string">&#x27;figure.figsize&#x27;</span>] = (<span class="number">10.0</span>, <span class="number">8.0</span>) <span class="comment"># 设置画图的尺寸</span></span><br><span class="line">plt.rcParams[<span class="string">&#x27;image.interpolation&#x27;</span>] = <span class="string">&#x27;nearest&#x27;</span></span><br><span class="line">plt.rcParams[<span class="string">&#x27;image.cmap&#x27;</span>] = <span class="string">&#x27;gray&#x27;</span></span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">show_images</span>(<span class="params">images</span>):</span> <span class="comment"># 定义画图工具</span></span><br><span class="line">    images = np.reshape(images, [images.shape[<span class="number">0</span>], -<span class="number">1</span>])</span><br><span class="line">    sqrtn = <span class="built_in">int</span>(np.ceil(np.sqrt(images.shape[<span class="number">0</span>])))</span><br><span class="line">    sqrtimg = <span class="built_in">int</span>(np.ceil(np.sqrt(images.shape[<span class="number">1</span>])))</span><br><span class="line"></span><br><span class="line">    fig = plt.figure(figsize=(sqrtn, sqrtn))</span><br><span class="line">    gs = gridspec.GridSpec(sqrtn, 
sqrtn)</span><br><span class="line">    gs.update(wspace=<span class="number">0.05</span>, hspace=<span class="number">0.05</span>)</span><br><span class="line"></span><br><span class="line">    <span class="keyword">for</span> i, img <span class="keyword">in</span> <span class="built_in">enumerate</span>(images):</span><br><span class="line">        ax = plt.subplot(gs[i])</span><br><span class="line">        plt.axis(<span class="string">&#x27;off&#x27;</span>)</span><br><span class="line">        ax.set_xticklabels([])</span><br><span class="line">        ax.set_yticklabels([])</span><br><span class="line">        ax.set_aspect(<span class="string">&#x27;equal&#x27;</span>)</span><br><span class="line">        plt.imshow(img.reshape([sqrtimg,sqrtimg]))</span><br><span class="line">    <span class="keyword">return</span> </span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">preprocess_img</span>(<span class="params">x</span>):</span></span><br><span class="line">    x = tfs.ToTensor()(x)</span><br><span class="line">    <span class="keyword">return</span> (x - <span class="number">0.5</span>) / <span class="number">0.5</span></span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">deprocess_img</span>(<span class="params">x</span>):</span></span><br><span class="line">    <span class="keyword">return</span> (x + <span class="number">1.0</span>) / <span class="number">2.0</span></span><br><span class="line"></span><br><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">ChunkSampler</span>(<span class="params">sampler.Sampler</span>):</span> <span class="comment"># 定义一个取样的函数</span></span><br><span class="line">    <span class="string">&quot;&quot;&quot;Samples elements sequentially from some offset. 
</span></span><br><span class="line"><span class="string">    Arguments:</span></span><br><span class="line"><span class="string">        num_samples: # of desired datapoints</span></span><br><span class="line"><span class="string">        start: offset where we should start selecting from</span></span><br><span class="line"><span class="string">    &quot;&quot;&quot;</span></span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">__init__</span>(<span class="params">self, num_samples, start=<span class="number">0</span></span>):</span></span><br><span class="line">        self.num_samples = num_samples</span><br><span class="line">        self.start = start</span><br><span class="line"></span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">__iter__</span>(<span class="params">self</span>):</span></span><br><span class="line">        <span class="keyword">return</span> <span class="built_in">iter</span>(<span class="built_in">range</span>(self.start, self.start + self.num_samples))</span><br><span class="line"></span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">__len__</span>(<span class="params">self</span>):</span></span><br><span class="line">        <span class="keyword">return</span> self.num_samples</span><br><span class="line"></span><br><span class="line">NUM_TRAIN = <span class="number">50000</span></span><br><span class="line">NUM_VAL = <span class="number">5000</span></span><br><span class="line"></span><br><span class="line">NOISE_DIM = <span class="number">96</span></span><br><span class="line">batch_size = <span class="number">128</span></span><br><span class="line"></span><br><span class="line">train_set = MNIST(<span class="string">&#x27;./data&#x27;</span>, train=<span class="literal">True</span>, download=<span class="literal">True</span>, transform=preprocess_img)</span><br><span 
class="line"></span><br><span class="line">train_data = DataLoader(train_set, batch_size=batch_size, sampler=ChunkSampler(NUM_TRAIN, <span class="number">0</span>))</span><br><span class="line"></span><br><span class="line">val_set = MNIST(<span class="string">&#x27;./data&#x27;</span>, train=<span class="literal">True</span>, download=<span class="literal">True</span>, transform=preprocess_img)</span><br><span class="line"></span><br><span class="line">val_data = DataLoader(val_set, batch_size=batch_size, sampler=ChunkSampler(NUM_VAL, NUM_TRAIN))</span><br><span class="line"></span><br><span class="line">imgs = deprocess_img(train_data.__iter__().<span class="built_in">next</span>()[<span class="number">0</span>].view(batch_size, <span class="number">784</span>)).numpy().squeeze() <span class="comment"># 可视化图片效果</span></span><br><span class="line">show_images(imgs)</span><br></pre></td></tr></table></figure><p><img src="" data-lazy-src="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN2/post/20201105142821.png"></p><p><strong>简单版本的生成对抗网络</strong></p><p>通过前面我们知道生成对抗网络有两个部分构成，一个是生成网络，一个是对抗网络，我们首先写一个简单版本的网络结构，生成网络和对抗网络都是简单的多层神经网络</p><p><strong>判别网络</strong></p><p>判别网络的结构非常简单，就是一个二分类器，结构如下:</p><ul><li>全连接(784 -&gt; 256)</li><li>leakyrelu, $\alpha$ 是 0.2</li><li>全连接(256 -&gt; 256)</li><li>leakyrelu, $\alpha$ 是 0.2</li><li>全连接(256 -&gt; 1)</li></ul><p>其中 leakyrelu 是指 f(x) = max($\alpha$ x, x)</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">discriminator</span>():</span></span><br><span class="line">    net = nn.Sequential(        </span><br><span 
class="line">            nn.Linear(<span class="number">784</span>, <span class="number">256</span>),</span><br><span class="line">            nn.LeakyReLU(<span class="number">0.2</span>),</span><br><span class="line">            nn.Linear(<span class="number">256</span>, <span class="number">256</span>),</span><br><span class="line">            nn.LeakyReLU(<span class="number">0.2</span>),</span><br><span class="line">            nn.Linear(<span class="number">256</span>, <span class="number">1</span>)</span><br><span class="line">        )</span><br><span class="line">    <span class="keyword">return</span> net</span><br></pre></td></tr></table></figure><p><strong>生成网络</strong></p><p>接下来我们看看生成网络，生成网络的结构也很简单，就是根据一个随机噪声生成一个和数据维度一样的张量，结构如下：</p><ul><li>全连接(噪音维度 -&gt; 1024)</li><li>relu</li><li>全连接(1024 -&gt; 1024)</li><li>relu</li><li>全连接(1024 -&gt; 784)</li><li>tanh 将数据裁剪到 -1 ~ 1 之间</li></ul><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">generator</span>(<span class="params">noise_dim=NOISE_DIM</span>):</span>   </span><br><span class="line">    net = nn.Sequential(</span><br><span class="line">        nn.Linear(noise_dim, <span class="number">1024</span>),</span><br><span class="line">        nn.ReLU(<span class="literal">True</span>),</span><br><span class="line">        nn.Linear(<span class="number">1024</span>, <span class="number">1024</span>),</span><br><span class="line">        nn.ReLU(<span class="literal">True</span>),</span><br><span class="line">        nn.Linear(<span class="number">1024</span>, 
<span class="number">784</span>),</span><br><span class="line">        nn.Tanh()</span><br><span class="line">    )</span><br><span class="line">    <span class="keyword">return</span> net</span><br></pre></td></tr></table></figure><p>接下来我们需要定义生成对抗网络的 loss，通过前面的讲解我们知道，对于对抗网络，相当于二分类问题，将真的判别为真的，假的判别为假的，作为辅助，可以参考一下论文中公式</p><p>$$ \ell_D = \mathbb{E}_{x \sim p_\text{data}}\left[\log D(x)\right] + \mathbb{E} _ {z \sim p(z)}\left[\log \left(1-D(G(z))\right)\right]$$</p><p>而对于生成网络，需要去骗过对抗网络，也就是将假的也判断为真的，作为辅助，可以参考一下论文中公式</p><p>$$\ell_G = \mathbb{E} _ {z \sim p(z)}\left[\log D(G(z))\right]$$</p><p>如果你还记得前面的二分类 loss，那么你就会发现上面这两个公式就是二分类 loss</p><p>$$ bce(s, y) = y * \log(s) + (1 - y) * \log(1 - s) $$</p><p>如果我们把 D(x) 看成真实数据的分类得分，那么 D(G(z)) 就是假数据的分类得分，所以上面判别器的 loss 就是将真实数据的得分判断为 1，假的数据的得分判断为 0，而生成器的 loss 就是将假的数据判断为 1</p><p>下面我们来实现一下</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span 
class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br><span class="line">87</span><br><span class="line">88</span><br><span class="line">89</span><br><span class="line">90</span><br><span class="line">91</span><br><span class="line">92</span><br><span class="line">93</span><br><span class="line">94</span><br><span class="line">95</span><br><span class="line">96</span><br><span 
class="line">97</span><br><span class="line">98</span><br><span class="line">99</span><br><span class="line">100</span><br><span class="line">101</span><br><span class="line">102</span><br><span class="line">103</span><br><span class="line">104</span><br><span class="line">105</span><br><span class="line">106</span><br><span class="line">107</span><br><span class="line">108</span><br><span class="line">109</span><br><span class="line">110</span><br><span class="line">111</span><br><span class="line">112</span><br><span class="line">113</span><br><span class="line">114</span><br><span class="line">115</span><br><span class="line">116</span><br><span class="line">117</span><br><span class="line">118</span><br><span class="line">119</span><br><span class="line">120</span><br><span class="line">121</span><br><span class="line">122</span><br><span class="line">123</span><br><span class="line">124</span><br><span class="line">125</span><br><span class="line">126</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">from</span> torch <span class="keyword">import</span> nn</span><br><span class="line"><span class="keyword">from</span> torch.autograd <span class="keyword">import</span> Variable</span><br><span class="line"></span><br><span class="line"><span class="keyword">import</span> torchvision.transforms <span class="keyword">as</span> tfs</span><br><span class="line"><span class="keyword">from</span> torch.utils.data <span class="keyword">import</span> DataLoader, sampler</span><br><span class="line"><span class="keyword">from</span> torchvision.datasets <span class="keyword">import</span> MNIST</span><br><span class="line"></span><br><span class="line"><span class="keyword">import</span> numpy <span class="keyword">as</span> np</span><br><span class="line"></span><br><span class="line">NUM_TRAIN = <span class="number">50000</span></span><br><span class="line">NUM_VAL 
= <span class="number">5000</span></span><br><span class="line"></span><br><span class="line">NOISE_DIM = <span class="number">96</span></span><br><span class="line">batch_size = <span class="number">128</span></span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">discriminator</span>():</span></span><br><span class="line">    net = nn.Sequential(        </span><br><span class="line">            nn.Linear(<span class="number">784</span>, <span class="number">256</span>),</span><br><span class="line">            nn.LeakyReLU(<span class="number">0.2</span>),</span><br><span class="line">            nn.Linear(<span class="number">256</span>, <span class="number">256</span>),</span><br><span class="line">            nn.LeakyReLU(<span class="number">0.2</span>),</span><br><span class="line">            nn.Linear(<span class="number">256</span>, <span class="number">1</span>)</span><br><span class="line">        )</span><br><span class="line">    <span class="keyword">return</span> net</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">generator</span>(<span class="params">noise_dim=NOISE_DIM</span>):</span>   </span><br><span class="line">    net = nn.Sequential(</span><br><span class="line">        nn.Linear(noise_dim, <span class="number">1024</span>),</span><br><span class="line">        nn.ReLU(<span class="literal">True</span>),</span><br><span class="line">        nn.Linear(<span class="number">1024</span>, <span class="number">1024</span>),</span><br><span class="line">        nn.ReLU(<span class="literal">True</span>),</span><br><span class="line">        nn.Linear(<span class="number">1024</span>, <span class="number">784</span>),</span><br><span class="line">        nn.Tanh()</span><br><span class="line">    )</span><br><span class="line">    <span class="keyword">return</span> 
net</span><br><span class="line"></span><br><span class="line">bce_loss = nn.BCEWithLogitsLoss()</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">discriminator_loss</span>(<span class="params">logits_real, logits_fake</span>):</span> <span class="comment"># 判别器的 loss</span></span><br><span class="line">    size = logits_real.shape[<span class="number">0</span>]</span><br><span class="line">    true_labels = Variable(torch.ones(size, <span class="number">1</span>)).<span class="built_in">float</span>().cuda()</span><br><span class="line">    false_labels = Variable(torch.zeros(size, <span class="number">1</span>)).<span class="built_in">float</span>().cuda()</span><br><span class="line">    loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)</span><br><span class="line">    <span class="keyword">return</span> loss</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">generator_loss</span>(<span class="params">logits_fake</span>):</span> <span class="comment"># 生成器的 loss  </span></span><br><span class="line">    size = logits_fake.shape[<span class="number">0</span>]</span><br><span class="line">    true_labels = Variable(torch.ones(size, <span class="number">1</span>)).<span class="built_in">float</span>().cuda()</span><br><span class="line">    loss = bce_loss(logits_fake, true_labels)</span><br><span class="line">    <span class="keyword">return</span> loss</span><br><span class="line"></span><br><span class="line"><span class="comment"># 使用 adam 来进行训练，学习率是 3e-4, beta1 是 0.5, beta2 是 0.999</span></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">get_optimizer</span>(<span class="params">net</span>):</span></span><br><span class="line">    optimizer = torch.optim.Adam(net.parameters(), lr=<span 
class="number">3e-4</span>, betas=(<span class="number">0.5</span>, <span class="number">0.999</span>))</span><br><span class="line">    <span class="keyword">return</span> optimizer</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">preprocess_img</span>(<span class="params">x</span>):</span></span><br><span class="line">    x = tfs.ToTensor()(x)</span><br><span class="line">    <span class="keyword">return</span> (x - <span class="number">0.5</span>) / <span class="number">0.5</span></span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">deprocess_img</span>(<span class="params">x</span>):</span></span><br><span class="line">    <span class="keyword">return</span> (x + <span class="number">1.0</span>) / <span class="number">2.0</span></span><br><span class="line"></span><br><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">ChunkSampler</span>(<span class="params">sampler.Sampler</span>):</span> <span class="comment"># 定义一个取样的函数</span></span><br><span class="line">    <span class="string">&quot;&quot;&quot;Samples elements sequentially from some offset. 
</span></span><br><span class="line"><span class="string">    Arguments:</span></span><br><span class="line"><span class="string">        num_samples: # of desired datapoints</span></span><br><span class="line"><span class="string">        start: offset where we should start selecting from</span></span><br><span class="line"><span class="string">    &quot;&quot;&quot;</span></span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">__init__</span>(<span class="params">self, num_samples, start=<span class="number">0</span></span>):</span></span><br><span class="line">        self.num_samples = num_samples</span><br><span class="line">        self.start = start</span><br><span class="line"></span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">__iter__</span>(<span class="params">self</span>):</span></span><br><span class="line">        <span class="keyword">return</span> <span class="built_in">iter</span>(<span class="built_in">range</span>(self.start, self.start + self.num_samples))</span><br><span class="line"></span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">__len__</span>(<span class="params">self</span>):</span></span><br><span class="line">        <span class="keyword">return</span> self.num_samples</span><br><span class="line"></span><br><span class="line">train_set = MNIST(<span class="string">&#x27;./data&#x27;</span>, train=<span class="literal">True</span>, download=<span class="literal">True</span>, transform=preprocess_img)</span><br><span class="line"></span><br><span class="line">train_data = DataLoader(train_set, batch_size=batch_size, sampler=ChunkSampler(NUM_TRAIN, <span class="number">0</span>))</span><br><span class="line"></span><br><span class="line">val_set = MNIST(<span class="string">&#x27;./data&#x27;</span>, train=<span class="literal">True</span>, download=<span 
class="literal">True</span>, transform=preprocess_img)</span><br><span class="line"></span><br><span class="line">val_data = DataLoader(val_set, batch_size=batch_size, sampler=ChunkSampler(NUM_VAL, NUM_TRAIN))</span><br><span class="line"></span><br><span class="line"><span class="comment">#下面我们开始训练一个这个简单的生成对抗网络</span></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">train_a_gan</span>(<span class="params">D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=<span class="number">250</span>, </span></span></span><br><span class="line"><span class="params"><span class="function">                noise_size=<span class="number">96</span>, num_epochs=<span class="number">10</span></span>):</span></span><br><span class="line">    iter_count = <span class="number">0</span></span><br><span class="line">    <span class="keyword">for</span> epoch <span class="keyword">in</span> <span class="built_in">range</span>(num_epochs):</span><br><span class="line">        <span class="keyword">for</span> x, _ <span class="keyword">in</span> train_data:</span><br><span class="line">            bs = x.shape[<span class="number">0</span>]</span><br><span class="line">            <span class="comment"># 判别网络</span></span><br><span class="line">            real_data = Variable(x).view(bs, -<span class="number">1</span>).cuda() <span class="comment"># 真实数据</span></span><br><span class="line">            logits_real = D_net(real_data) <span class="comment"># 判别网络得分</span></span><br><span class="line">            </span><br><span class="line">            sample_noise = (torch.rand(bs, noise_size) - <span class="number">0.5</span>) / <span class="number">0.5</span> <span class="comment"># -1 ~ 1 的均匀分布</span></span><br><span class="line">            g_fake_seed = Variable(sample_noise).cuda()</span><br><span class="line">            fake_images = G_net(g_fake_seed) <span class="comment"># 
生成的假的数据</span></span><br><span class="line">            logits_fake = D_net(fake_images) <span class="comment"># 判别网络得分</span></span><br><span class="line"></span><br><span class="line">            d_total_error = discriminator_loss(logits_real, logits_fake) <span class="comment"># 判别器的 loss</span></span><br><span class="line">            D_optimizer.zero_grad()</span><br><span class="line">            d_total_error.backward()</span><br><span class="line">            D_optimizer.step() <span class="comment"># 优化判别网络</span></span><br><span class="line">            </span><br><span class="line">            <span class="comment"># 生成网络</span></span><br><span class="line">            g_fake_seed = Variable(sample_noise).cuda()</span><br><span class="line">            fake_images = G_net(g_fake_seed) <span class="comment"># 生成的假的数据</span></span><br><span class="line"></span><br><span class="line">            gen_logits_fake = D_net(fake_images)</span><br><span class="line">            g_error = generator_loss(gen_logits_fake) <span class="comment"># 生成网络的 loss</span></span><br><span class="line">            G_optimizer.zero_grad()</span><br><span class="line">            g_error.backward()</span><br><span class="line">            G_optimizer.step() <span class="comment"># 优化生成网络</span></span><br><span class="line"></span><br><span class="line">            <span class="keyword">if</span> (iter_count % show_every == <span class="number">0</span>):</span><br><span class="line">                <span class="built_in">print</span>(<span class="string">&#x27;Iter: &#123;&#125;, D: &#123;:.4&#125;, G:&#123;:.4&#125;&#x27;</span>.<span class="built_in">format</span>(iter_count, d_total_error.data, g_error.data))</span><br><span class="line">                imgs_numpy = deprocess_img(fake_images.data.cpu().numpy())</span><br><span class="line">                show_images(imgs_numpy[<span class="number">0</span>:<span class="number">16</span>])</span><br><span class="line">        
        plt.show()</span><br><span class="line">                <span class="built_in">print</span>()</span><br><span class="line">            iter_count += <span class="number">1</span></span><br></pre></td></tr></table></figure><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line">D = discriminator().cuda()</span><br><span class="line">G = generator().cuda()</span><br><span class="line"></span><br><span class="line">D_optim = get_optimizer(D)</span><br><span class="line">G_optim = get_optimizer(G)</span><br><span class="line"></span><br><span class="line">train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss)</span><br></pre></td></tr></table></figure><p><img src="" data-lazy-src="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN2/post/20201105174301-min.gif"></p><p>我们已经完成了一个简单的生成对抗网络，是不是非常容易呢。但是可以看到效果并不是特别好，生成的数字也不是特别完整，因为我们仅仅使用了简单的多层全连接网络。</p><p>除了这种最基本的生成对抗网络之外，还有很多生成对抗网络的变式，有结构上的变式，也有 loss 上的变式，我们先讲一讲其中一种在 loss 上的变式，Least Squares GAN</p><p><strong>Least Squares GAN</strong></p><p><a target="_blank" rel="external nofollow noopener noreferrer" href="https://arxiv.org/abs/1611.04076">Least Squares GAN</a> 比最原始的 GANs 的 loss 更加稳定，通过名字我们也能够看出这种 GAN 是通过最小平方误差来进行估计，而不是通过二分类的损失函数，下面我们看看 loss 的计算公式</p><p>$$\ell_G = \frac{1}{2}\mathbb{E} _ {z \sim p(z)}\left[\left(D(G(z))-1\right)^2\right]$$</p><p>$$ \ell_D = \frac{1}{2}\mathbb{E}_{x \sim p_\text{data}}\left[\left(D(x)-1\right)^2\right] + \frac{1}{2}\mathbb{E} _ {z \sim p(z)}\left[ \left(D(G(z))\right)^2\right]$$</p><p>可以看到 Least Squares GAN 通过最小二乘代替了二分类的 loss，下面我们定义一下 loss 函数</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span 
class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">ls_discriminator_loss</span>(<span class="params">scores_real, scores_fake</span>):</span></span><br><span class="line">    loss = <span class="number">0.5</span> * ((scores_real - <span class="number">1</span>) ** <span class="number">2</span>).mean() + <span class="number">0.5</span> * (scores_fake ** <span class="number">2</span>).mean()</span><br><span class="line">    <span class="keyword">return</span> loss</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">ls_generator_loss</span>(<span class="params">scores_fake</span>):</span></span><br><span class="line">    loss = <span class="number">0.5</span> * ((scores_fake - <span class="number">1</span>) ** <span class="number">2</span>).mean()</span><br><span class="line">    <span class="keyword">return</span> loss</span><br></pre></td></tr></table></figure><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line">D = discriminator().cuda()</span><br><span class="line">G = generator().cuda()</span><br><span class="line"></span><br><span class="line">D_optim = get_optimizer(D)</span><br><span class="line">G_optim = get_optimizer(G)</span><br><span class="line"></span><br><span class="line">train_a_gan(D, G, D_optim, G_optim, ls_discriminator_loss, ls_generator_loss)</span><br></pre></td></tr></table></figure><p><img src="" 
data-lazy-src="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN2/post/20201105174302-min.gif"></p><p>上面我们讲了 最基本的 GAN 和 least squares GAN，最后我们讲一讲使用卷积网络的 GAN，叫做深度卷积生成对抗网络</p><p><strong>Deep Convolutional GANs</strong></p><p>深度卷积生成对抗网络特别简单，就是将生成网络和对抗网络都改成了卷积网络的形式，下面我们来实现一下</p><p>卷积判别网络就是一个一般的卷积网络，结构如下</p><ul><li>32 Filters, 5x5, Stride 1, Leaky ReLU(alpha=0.01)</li><li>Max Pool 2x2, Stride 2</li><li>64 Filters, 5x5, Stride 1, Leaky ReLU(alpha=0.01)</li><li>Max Pool 2x2, Stride 2</li><li>Fully Connected size 4 x 4 x 64, Leaky ReLU(alpha=0.01)</li><li>Fully Connected size 1</li></ul><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">build_dc_classifier</span>(<span class="params">nn.Module</span>):</span></span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">__init__</span>(<span class="params">self</span>):</span></span><br><span class="line">        <span class="built_in">super</span>(build_dc_classifier, self).__init__()</span><br><span class="line">        self.conv = nn.Sequential(</span><br><span class="line">            nn.Conv2d(<span class="number">1</span>, <span 
class="number">32</span>, <span class="number">5</span>, <span class="number">1</span>),</span><br><span class="line">            nn.LeakyReLU(<span class="number">0.01</span>),</span><br><span class="line">            nn.MaxPool2d(<span class="number">2</span>, <span class="number">2</span>),</span><br><span class="line">            nn.Conv2d(<span class="number">32</span>, <span class="number">64</span>, <span class="number">5</span>, <span class="number">1</span>),</span><br><span class="line">            nn.LeakyReLU(<span class="number">0.01</span>),</span><br><span class="line">            nn.MaxPool2d(<span class="number">2</span>, <span class="number">2</span>)</span><br><span class="line">        )</span><br><span class="line">        self.fc = nn.Sequential(</span><br><span class="line">            nn.Linear(<span class="number">1024</span>, <span class="number">1024</span>),</span><br><span class="line">            nn.LeakyReLU(<span class="number">0.01</span>),</span><br><span class="line">            nn.Linear(<span class="number">1024</span>, <span class="number">1</span>)</span><br><span class="line">        )</span><br><span class="line">        </span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">forward</span>(<span class="params">self, x</span>):</span></span><br><span class="line">        x = self.conv(x)</span><br><span class="line">        x = x.view(x.shape[<span class="number">0</span>], -<span class="number">1</span>)</span><br><span class="line">        x = self.fc(x)</span><br><span class="line">        <span class="keyword">return</span> x</span><br></pre></td></tr></table></figure><p>卷积生成网络需要将一个低维的噪声向量变成一个图片数据，结构如下</p><ul><li>Fully connected of size 1024, ReLU</li><li>BatchNorm</li><li>Fully connected of size 7 x 7 x 128, ReLU</li><li>BatchNorm</li><li>Reshape into Image Tensor</li><li>64 conv2d^T filters of 4x4, stride 2, padding 1, ReLU</li><li>BatchNorm</li><li>1 conv2d^T 
filter of 4x4, stride 2, padding 1, TanH</li></ul><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span 
class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">build_dc_generator</span>(<span class="params">nn.Module</span>):</span> </span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">__init__</span>(<span class="params">self, noise_dim=NOISE_DIM</span>):</span></span><br><span class="line">        <span class="built_in">super</span>(build_dc_generator, self).__init__()</span><br><span class="line">        self.fc = nn.Sequential(</span><br><span class="line">            nn.Linear(noise_dim, <span class="number">1024</span>),</span><br><span class="line">            nn.ReLU(<span class="literal">True</span>),</span><br><span class="line">            nn.BatchNorm1d(<span class="number">1024</span>),</span><br><span class="line">            nn.Linear(<span class="number">1024</span>, <span class="number">7</span> * <span class="number">7</span> * <span class="number">128</span>),</span><br><span class="line">            nn.ReLU(<span class="literal">True</span>),</span><br><span class="line">            nn.BatchNorm1d(<span class="number">7</span> * <span class="number">7</span> * <span class="number">128</span>)</span><br><span class="line">        )</span><br><span class="line">        </span><br><span class="line">        self.conv = nn.Sequential(</span><br><span class="line">            nn.ConvTranspose2d(<span class="number">128</span>, <span class="number">64</span>, <span class="number">4</span>, <span class="number">2</span>, 
padding=<span class="number">1</span>),</span><br><span class="line">            nn.ReLU(<span class="literal">True</span>),</span><br><span class="line">            nn.BatchNorm2d(<span class="number">64</span>),</span><br><span class="line">            nn.ConvTranspose2d(<span class="number">64</span>, <span class="number">1</span>, <span class="number">4</span>, <span class="number">2</span>, padding=<span class="number">1</span>),</span><br><span class="line">            nn.Tanh()</span><br><span class="line">        )</span><br><span class="line">        </span><br><span class="line">    <span class="function"><span class="keyword">def</span> <span class="title">forward</span>(<span class="params">self, x</span>):</span></span><br><span class="line">        x = self.fc(x)</span><br><span class="line">        x = x.view(x.shape[<span class="number">0</span>], <span class="number">128</span>, <span class="number">7</span>, <span class="number">7</span>) <span class="comment"># reshape 通道是 128，大小是 7x7</span></span><br><span class="line">        x = self.conv(x)</span><br><span class="line">        <span class="keyword">return</span> x</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">train_dc_gan</span>(<span class="params">D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=<span class="number">250</span>, </span></span></span><br><span class="line"><span class="params"><span class="function">                noise_size=<span class="number">96</span>, num_epochs=<span class="number">10</span></span>):</span></span><br><span class="line">    iter_count = <span class="number">0</span></span><br><span class="line">    <span class="keyword">for</span> epoch <span class="keyword">in</span> <span class="built_in">range</span>(num_epochs):</span><br><span class="line">        <span class="keyword">for</span> x, _ <span class="keyword">in</span> 
train_data:</span><br><span class="line">            bs = x.shape[<span class="number">0</span>]</span><br><span class="line">            <span class="comment"># 判别网络</span></span><br><span class="line">            real_data = Variable(x).cuda() <span class="comment"># 真实数据</span></span><br><span class="line">            logits_real = D_net(real_data) <span class="comment"># 判别网络得分</span></span><br><span class="line">            </span><br><span class="line">            sample_noise = (torch.rand(bs, noise_size) - <span class="number">0.5</span>) / <span class="number">0.5</span> <span class="comment"># -1 ~ 1 的均匀分布</span></span><br><span class="line">            g_fake_seed = Variable(sample_noise).cuda()</span><br><span class="line">            fake_images = G_net(g_fake_seed) <span class="comment"># 生成的假的数据</span></span><br><span class="line">            logits_fake = D_net(fake_images) <span class="comment"># 判别网络得分</span></span><br><span class="line"></span><br><span class="line">            d_total_error = discriminator_loss(logits_real, logits_fake) <span class="comment"># 判别器的 loss</span></span><br><span class="line">            D_optimizer.zero_grad()</span><br><span class="line">            d_total_error.backward()</span><br><span class="line">            D_optimizer.step() <span class="comment"># 优化判别网络</span></span><br><span class="line">            </span><br><span class="line">            <span class="comment"># 生成网络</span></span><br><span class="line">            g_fake_seed = Variable(sample_noise).cuda()</span><br><span class="line">            fake_images = G_net(g_fake_seed) <span class="comment"># 生成的假的数据</span></span><br><span class="line"></span><br><span class="line">            gen_logits_fake = D_net(fake_images)</span><br><span class="line">            g_error = generator_loss(gen_logits_fake) <span class="comment"># 生成网络的 loss</span></span><br><span class="line">            G_optimizer.zero_grad()</span><br><span class="line">            
g_error.backward()</span><br><span class="line">            G_optimizer.step() <span class="comment"># 优化生成网络</span></span><br><span class="line"></span><br><span class="line">            <span class="keyword">if</span> (iter_count % show_every == <span class="number">0</span>):</span><br><span class="line">                <span class="built_in">print</span>(<span class="string">&#x27;Iter: &#123;&#125;, D: &#123;:.4&#125;, G:&#123;:.4&#125;&#x27;</span>.<span class="built_in">format</span>(iter_count, d_total_error.data, g_error.data))</span><br><span class="line">                imgs_numpy = deprocess_img(fake_images.data.cpu().numpy())</span><br><span class="line">                show_images(imgs_numpy[<span class="number">0</span>:<span class="number">16</span>])</span><br><span class="line">                plt.show()</span><br><span class="line">                <span class="built_in">print</span>()</span><br><span class="line">            iter_count += <span class="number">1</span></span><br><span class="line"></span><br><span class="line">D_DC = build_dc_classifier().cuda()</span><br><span class="line">G_DC = build_dc_generator().cuda()</span><br><span class="line"></span><br><span class="line">D_DC_optim = get_optimizer(D_DC)</span><br><span class="line">G_DC_optim = get_optimizer(G_DC)</span><br><span class="line"></span><br><span class="line">train_dc_gan(D_DC, G_DC, D_DC_optim, G_DC_optim, discriminator_loss, generator_loss, num_epochs=<span class="number">5</span>)</span><br></pre></td></tr></table></figure><p><img src="" data-lazy-src="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN2/post/20201105174303-min.gif"><br>可以看到，通过 DCGANs 能够得到更加清楚的结果</p><h3 id="6-3-Improving-GAN"><a href="#6-3-Improving-GAN" class="headerlink" title="6.3 Improving GAN"></a>6.3 Improving GAN</h3><h4 id="6-3-1-Wasserstein-GAN"><a href="#6-3-1-Wasserstein-GAN" class="headerlink" title="6.3.1 Wasserstein GAN"></a>6.3.1 Wasserstein GAN</h4><p>Wasserstein 
GAN是GAN的一种变式，WGAN的出现解决了下面这些难点</p><ul><li>彻底解决了训练不稳定的问题</li><li>基本解决了collapse mode 的问题，确保了生成样本的多样性</li><li>训练中有一个像交叉熵、准确率这样的数值指标来衡量训练的进程，数值越小代表GAN训练得越好，同时也代表着生成的图片质量越高</li><li>不需要精心设计网络结构也能取得较好的效果</li></ul><h3 id="6-4-应用介绍"><a href="#6-4-应用介绍" class="headerlink" title="6.4 应用介绍"></a>6.4 应用介绍</h3><h4 id="6-4-1-Conditional-GAN"><a href="#6-4-1-Conditional-GAN" class="headerlink" title="6.4.1 Conditional GAN"></a>6.4.1 Conditional GAN</h4><p>Conditional GAN的一个应用是文字生成图片</p><h4 id="6-4-2-Cycle-GAN"><a href="#6-4-2-Cycle-GAN" class="headerlink" title="6.4.2 Cycle GAN"></a>6.4.2 Cycle GAN</h4><p>根据一个人的作品，想象他完成其他场景会是什么样</p><h2 id="第七章-深度学习实战"><a href="#第七章-深度学习实战" class="headerlink" title="第七章 深度学习实战"></a>第七章 深度学习实战</h2><h3 id="7-1-实例一，猫狗大战：运用预训练卷积神经网络进行特征提取与预训"><a href="#7-1-实例一，猫狗大战：运用预训练卷积神经网络进行特征提取与预训" class="headerlink" title="7.1 实例一，猫狗大战：运用预训练卷积神经网络进行特征提取与预训"></a>7.1 实例一，猫狗大战：运用预训练卷积神经网络进行特征提取与预训</h3><h4 id="7-1-1-背景介绍"><a href="#7-1-1-背景介绍" class="headerlink" title="7.1.1 背景介绍"></a>7.1.1 背景介绍</h4><p>Asirra是一个图像识别机制的验证码，其有很多不同猫狗的照片（三百万张），可以用它的子集当作训练集</p><h4 id="7-1-2-原理分析"><a href="#7-1-2-原理分析" class="headerlink" title="7.1.2 原理分析"></a>7.1.2 原理分析</h4><p>对于这个问题，简单的网络模型可能效果并不好，这个时候，使用一些成熟的模型，比如VggNet，GoogleNet，ResNet等可以帮助我们解决问题，为了节省计算资源和时间，可以通过迁移学习实现。</p><p><strong>迁移学习</strong></p><p>对于一个特定任务，如果没有来自该任务足够的数据集，传统的监督学习无法支持，而迁移学习允许通过借用已经存在的一些相关任务的标签数据来处理这些场景，把解决相关任务时获得的知识存储下来，并将它应用到我们感兴趣的目标任务中。</p><p>卷积神经网络可以理解为两个部分：前面的<strong>卷积</strong>部分和后面的<strong>分类</strong>部分，卷积部分主要用于提取图片特征，而预训练的网络对于特征提取效果已经非常好。我们可以直接用预训练的网络卷积部分来提取我们自己的图片特征，而对于自己的任务，比如猫狗二分类，就用自己的分类全连接层即可。</p><p>当然，迁移学习并不是任何时候都能使用，需要它们<strong>完成的任务是相关的</strong>，所以迁移学习在相似数据集上的应用效果才是良好的。</p><p><strong>实现方法</strong></p><ol><li>第一种方法：导入预训练的卷积网络，将最后的全连接层改成我们自己设计的全连接层，然后更新整个网络，最后能特别快地达到收敛</li><li>第二种方法：锁定前面卷积层的参数，让网络训练只更新最后全连接层的参数，可以使训练时间大大减少</li><li>第三种方法：使用多个预训练好的网络，将它们并联在一起，图片经过每个网络都会得到特征图，我们将这些特征图拼接在一起进入最后的全连接层</li></ol><h4 id="7-1-3-代码实现"><a href="#7-1-3-代码实现" class="headerlink" title="7.1.3 代码实现"></a>7.1.3
代码实现</h4><p>1.数据预处理</p><p>数据集可以去 <a target="_blank" rel="external nofollow noopener noreferrer" href="https://www.kaggle.com/c/dogs-vs-cats/data">https://www.kaggle.com/c/dogs-vs-cats/data</a> 下载</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> os</span><br><span class="line"><span class="keyword">import</span> shutil</span><br><span class="line"></span><br><span class="line">train_root = <span class="string">&#x27;./data/dogs-vs-cats/train/&#x27;</span></span><br><span class="line">val_root = <span class="string">&#x27;./data/dogs-vs-cats/val/&#x27;</span></span><br><span class="line">data_file=os.listdir(train_root)</span><br><span class="line"><span class="comment">#print(data_file)</span></span><br><span class="line">dog_file = <span 
class="built_in">list</span>(<span class="built_in">filter</span>(<span class="keyword">lambda</span> x:x.split(<span class="string">&quot;.&quot;</span>)[<span class="number">0</span>]==<span class="string">&#x27;dog&#x27;</span><span class="keyword">and</span> x!=<span class="string">&quot;dog&quot;</span>,data_file))</span><br><span class="line">cat_file = <span class="built_in">list</span>(<span class="built_in">filter</span>(<span class="keyword">lambda</span> x:x.split(<span class="string">&quot;.&quot;</span>)[<span class="number">0</span>]==<span class="string">&#x27;cat&#x27;</span><span class="keyword">and</span> x!=<span class="string">&quot;cat&quot;</span>,data_file))</span><br><span class="line"></span><br><span class="line">root = <span class="string">&#x27;./data/dogs-vs-cats/&#x27;</span></span><br><span class="line"><span class="keyword">if</span> <span class="keyword">not</span> os.path.exists(train_root+<span class="string">&#x27;dog/&#x27;</span>):</span><br><span class="line">    os.makedirs(train_root+<span class="string">&#x27;dog/&#x27;</span>)</span><br><span class="line"><span class="keyword">if</span> <span class="keyword">not</span> os.path.exists(train_root+<span class="string">&#x27;cat/&#x27;</span>):</span><br><span class="line">    os.makedirs(train_root+<span class="string">&#x27;cat/&#x27;</span>)</span><br><span class="line"><span class="keyword">if</span> <span class="keyword">not</span> os.path.exists(val_root+<span class="string">&#x27;dog/&#x27;</span>):</span><br><span class="line">    os.makedirs(val_root+<span class="string">&#x27;dog/&#x27;</span>)</span><br><span class="line"><span class="keyword">if</span> <span class="keyword">not</span> os.path.exists(val_root+<span class="string">&#x27;cat/&#x27;</span>):</span><br><span class="line">    os.makedirs(val_root+<span class="string">&#x27;cat/&#x27;</span>)</span><br><span class="line"></span><br><span class="line"><span class="keyword">for</span> i <span 
class="keyword">in</span> <span class="built_in">range</span>(<span class="built_in">len</span>(dog_file)):</span><br><span class="line">    pic_path = root+<span class="string">&#x27;train/&#x27;</span>+dog_file[i]</span><br><span class="line">    <span class="keyword">if</span> i &lt; <span class="built_in">len</span>(dog_file)*<span class="number">0.9</span>:</span><br><span class="line">        obj_path = train_root+<span class="string">&#x27;dog/&#x27;</span>+dog_file[i]</span><br><span class="line">    <span class="keyword">else</span>:</span><br><span class="line">        obj_path = val_root+<span class="string">&#x27;dog/&#x27;</span>+dog_file[i]</span><br><span class="line">    shutil.move(pic_path,obj_path)</span><br><span class="line"></span><br><span class="line"><span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(<span class="built_in">len</span>(cat_file)):</span><br><span class="line">    pic_path = root+<span class="string">&#x27;train/&#x27;</span>+cat_file[i]</span><br><span class="line">    <span class="keyword">if</span> i &lt; <span class="built_in">len</span>(cat_file)*<span class="number">0.9</span>:</span><br><span class="line">        obj_path = train_root+<span class="string">&#x27;cat/&#x27;</span>+cat_file[i]</span><br><span class="line">    <span class="keyword">else</span>:</span><br><span class="line">        obj_path = val_root+<span class="string">&#x27;cat/&#x27;</span>+cat_file[i]</span><br><span class="line">    shutil.move(pic_path,obj_path)</span><br></pre></td></tr></table></figure><p>上面的操作实现了，将猫狗照片分别移动到训练集和验证集，其中90%的数据作为训练集，10%的图片作为验证集，使用<code>shutil.move()</code>来移动图片</p><p>2.迁移学习模型训练</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span 
class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span 
class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br><span class="line">87</span><br><span class="line">88</span><br><span class="line">89</span><br><span class="line">90</span><br><span class="line">91</span><br><span class="line">92</span><br><span class="line">93</span><br><span class="line">94</span><br><span class="line">95</span><br><span class="line">96</span><br><span class="line">97</span><br><span class="line">98</span><br><span class="line">99</span><br><span class="line">100</span><br><span class="line">101</span><br><span class="line">102</span><br><span class="line">103</span><br><span class="line">104</span><br><span class="line">105</span><br><span class="line">106</span><br><span class="line">107</span><br><span class="line">108</span><br><span class="line">109</span><br><span class="line">110</span><br><span class="line">111</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">import</span> torch.nn <span class="keyword">as</span> nn</span><br><span class="line"><span class="keyword">import</span> torchvision</span><br><span class="line"><span class="keyword">from</span> torch.autograd <span class="keyword">import</span> Variable</span><br><span class="line"><span class="keyword">from</span> torchvision <span class="keyword">import</span> models,transforms,datasets</span><br><span class="line"><span class="keyword">from</span> torch.utils.data <span 
class="keyword">import</span> DataLoader</span><br><span class="line"><span class="keyword">import</span> numpy <span class="keyword">as</span> np</span><br><span class="line"><span class="keyword">import</span> matplotlib.pyplot <span class="keyword">as</span> plt</span><br><span class="line"><span class="keyword">import</span> time</span><br><span class="line"></span><br><span class="line">img_classes=<span class="number">2</span></span><br><span class="line">epoch_num = <span class="number">2</span></span><br><span class="line">path = <span class="string">&quot;./data/dogs-vs-cats/&quot;</span></span><br><span class="line"></span><br><span class="line"><span class="comment">#数据</span></span><br><span class="line">data_transform = transforms.Compose([</span><br><span class="line">    transforms.CenterCrop(<span class="number">224</span>),</span><br><span class="line">    transforms.ToTensor(),</span><br><span class="line">    transforms.Normalize([<span class="number">0.5</span>, <span class="number">0.5</span>, <span class="number">0.5</span>],[<span class="number">0.5</span>, <span class="number">0.5</span>, <span class="number">0.5</span>])</span><br><span class="line">])</span><br><span class="line"></span><br><span class="line"><span class="comment"># ImageFOLDER 返回的是一个list，这里的写法是字典的形式</span></span><br><span class="line">data_image = &#123;x: datasets.ImageFolder(root=os.path.join(path, x),transform=data_transform) <span class="keyword">for</span> x <span class="keyword">in</span> [<span class="string">&quot;train&quot;</span>, <span class="string">&quot;val&quot;</span>]&#125;</span><br><span class="line">data_loader_image = &#123;x: DataLoader(dataset=data_image[x],batch_size=<span class="number">4</span>,shuffle=<span class="literal">True</span>) <span class="keyword">for</span> x <span class="keyword">in</span> [<span class="string">&quot;train&quot;</span>, <span class="string">&quot;val&quot;</span>]&#125;</span><br><span class="line"></span><br><span 
class="line"><span class="comment"># 分类</span></span><br><span class="line">classes = data_image[<span class="string">&quot;train&quot;</span>].classes <span class="comment"># 按文件夹名字分类</span></span><br><span class="line">classes_index = data_image[<span class="string">&quot;train&quot;</span>].class_to_idx <span class="comment"># 文件夹类名所对应的链值</span></span><br><span class="line"><span class="comment"># 打印类别</span></span><br><span class="line"><span class="built_in">print</span>(classes) </span><br><span class="line"><span class="built_in">print</span>(classes_index)</span><br><span class="line"><span class="comment"># 打印训练集，验证集大小</span></span><br><span class="line"><span class="built_in">print</span>(<span class="string">&quot;train data set:&quot;</span>, <span class="built_in">len</span>(data_image[<span class="string">&quot;train&quot;</span>]))</span><br><span class="line"><span class="built_in">print</span>(<span class="string">&quot;val data set:&quot;</span>, <span class="built_in">len</span>(data_image[<span class="string">&quot;val&quot;</span>]))</span><br><span class="line"></span><br><span class="line"><span class="comment">#导入预训练的网络，并修改全连接层</span></span><br><span class="line">model = models.resnet18(pretrained=<span class="literal">True</span>) <span class="comment"># 18层的残差网络</span></span><br><span class="line"><span class="comment">#print(model)</span></span><br><span class="line"></span><br><span class="line"><span class="keyword">for</span> parma <span class="keyword">in</span> model.parameters():</span><br><span class="line">    parma.requires_grad = <span class="literal">False</span>  <span class="comment"># 不进行梯度更新</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># 改变模型的全连接层，本项目只需要输出2类</span></span><br><span class="line">model.fc = nn.Sequential(nn.Linear(<span class="number">512</span>, <span class="number">256</span>),</span><br><span class="line">                                       
nn.ReLU(),</span><br><span class="line">                                       nn.Dropout(p=<span class="number">0.5</span>),</span><br><span class="line">                                       nn.Linear(<span class="number">256</span>, <span class="number">256</span>),</span><br><span class="line">                                       nn.ReLU(),</span><br><span class="line">                                       nn.Dropout(p=<span class="number">0.5</span>),</span><br><span class="line">                                       nn.Linear(<span class="number">256</span>, <span class="number">2</span>))</span><br><span class="line"></span><br><span class="line"><span class="keyword">for</span> index, parma <span class="keyword">in</span> <span class="built_in">enumerate</span>(model.fc.parameters()):</span><br><span class="line">    parma.requires_grad = <span class="literal">True</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># 是否有GPU</span></span><br><span class="line">use_gpu = torch.cuda.is_available()</span><br><span class="line"><span class="built_in">print</span>(<span class="string">&quot;Find GPU: &quot;</span>,use_gpu)</span><br><span class="line"><span class="keyword">if</span> use_gpu:</span><br><span class="line">    model = model.cuda()</span><br><span class="line"><span class="comment">#print(model)</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># 定义代价函数</span></span><br><span class="line">cost = torch.nn.CrossEntropyLoss()</span><br><span class="line"><span class="comment"># 定义优化器</span></span><br><span class="line">optimizer = torch.optim.Adam(model.fc.parameters(),lr=<span class="number">1e-4</span>)</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">train</span>():</span></span><br><span class="line">    <span class="keyword">for</span> epoch <span class="keyword">in</span> <span 
class="built_in">range</span>(epoch_num):</span><br><span class="line">        since = time.time()</span><br><span class="line">        <span class="built_in">print</span>(<span class="string">&quot;Epoch&#123;&#125;/&#123;&#125;&quot;</span>.<span class="built_in">format</span>(epoch+<span class="number">1</span>, epoch_num))</span><br><span class="line">        <span class="built_in">print</span>(<span class="string">&quot;-&quot;</span> * <span class="number">10</span>)</span><br><span class="line">        <span class="keyword">for</span> param <span class="keyword">in</span> [<span class="string">&quot;train&quot;</span>, <span class="string">&quot;val&quot;</span>]:</span><br><span class="line">            <span class="keyword">if</span> param == <span class="string">&quot;train&quot;</span>:</span><br><span class="line">                model.train = <span class="literal">True</span></span><br><span class="line">            <span class="keyword">else</span>:</span><br><span class="line">                model.train = <span class="literal">False</span></span><br><span class="line"></span><br><span class="line">            running_loss = <span class="number">0.0</span></span><br><span class="line">            running_correct = <span class="number">0</span></span><br><span class="line">            batch = <span class="number">0</span></span><br><span class="line">            <span class="keyword">for</span> data <span class="keyword">in</span> data_loader_image[param]:</span><br><span class="line">                batch += <span class="number">1</span></span><br><span class="line">                X, y = data</span><br><span class="line">                <span class="keyword">if</span> use_gpu:</span><br><span class="line">                    X, y = Variable(X.cuda()), Variable(y.cuda())</span><br><span class="line">                <span class="keyword">else</span>:</span><br><span class="line">                    X, y = Variable(X), Variable(y)</span><br><span 
class="line"></span><br><span class="line">                optimizer.zero_grad()</span><br><span class="line">                y_pred = model(X)</span><br><span class="line">                _, pred = torch.<span class="built_in">max</span>(y_pred.data, <span class="number">1</span>)</span><br><span class="line">                loss = cost(y_pred,y)</span><br><span class="line">                <span class="keyword">if</span> param == <span class="string">&quot;train&quot;</span>:</span><br><span class="line">                    loss.backward()</span><br><span class="line">                    optimizer.step()</span><br><span class="line">                running_loss += loss.item()</span><br><span class="line">                <span class="comment"># running_loss += loss.data</span></span><br><span class="line">                running_correct += torch.<span class="built_in">sum</span>(pred == y.data)</span><br><span class="line">                <span class="keyword">if</span> batch % <span class="number">5</span> == <span class="number">0</span> <span class="keyword">and</span> param == <span class="string">&quot;train&quot;</span>:</span><br><span class="line">                    <span class="built_in">print</span>(<span class="string">&quot;Batch &#123;&#125;, Train Loss:&#123;:.4f&#125;, Train ACC:&#123;:.4f&#125;&quot;</span>.<span class="built_in">format</span>(</span><br><span class="line">                        batch, running_loss / (<span class="number">4</span> * batch), <span class="number">100</span> * running_correct / (<span class="number">4</span> * batch)))</span><br><span class="line"></span><br><span class="line">            epoch_loss = running_loss / <span class="built_in">len</span>(data_image[param])</span><br><span class="line">            epoch_correct = <span class="number">100</span> * running_correct / <span class="built_in">len</span>(data_image[param])</span><br><span class="line"></span><br><span class="line">            <span 
class="built_in">print</span>(<span class="string">&quot;&#123;&#125; Loss:&#123;:.4f&#125;, Correct:&#123;:.4f&#125;&quot;</span>.<span class="built_in">format</span>(param, epoch_loss, epoch_correct))</span><br><span class="line">        now_time = time.time() - since</span><br><span class="line">        <span class="built_in">print</span>(<span class="string">&quot;Training time is:&#123;:.0f&#125;m &#123;:.0f&#125;s&quot;</span>.<span class="built_in">format</span>(now_time // <span class="number">60</span>, now_time % <span class="number">60</span>))</span><br><span class="line"></span><br><span class="line">train()</span><br><span class="line">torch.save(model, <span class="string">&#x27;dogsvscats.pth&#x27;</span>)</span><br></pre></td></tr></table></figure><p>测试</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span 
class="line">36</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> os</span><br><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">import</span> torchvision</span><br><span class="line"><span class="keyword">from</span> torchvision <span class="keyword">import</span> datasets, transforms, models</span><br><span class="line"><span class="keyword">import</span> numpy <span class="keyword">as</span> np</span><br><span class="line"><span class="keyword">import</span> matplotlib.pyplot <span class="keyword">as</span> plt</span><br><span class="line"><span class="keyword">from</span> torch.autograd <span class="keyword">import</span> Variable</span><br><span class="line"><span class="keyword">import</span> time</span><br><span class="line">model = torch.load(<span class="string">&#x27;dogsvscats.pth&#x27;</span>)</span><br><span class="line">path = <span class="string">&quot;./data/dogs-vs-cats&quot;</span></span><br><span class="line"></span><br><span class="line">transform = transforms.Compose([transforms.CenterCrop(<span class="number">224</span>),</span><br><span class="line">                                transforms.ToTensor(),</span><br><span class="line">                                transforms.Normalize([<span class="number">0.5</span>, <span class="number">0.5</span>, <span class="number">0.5</span>], [<span class="number">0.5</span>, <span class="number">0.5</span>, <span class="number">0.5</span>])])</span><br><span class="line"></span><br><span class="line">data_test_img = datasets.ImageFolder(root=path+<span class="string">&quot;/val/&quot;</span>, transform = transform) </span><br><span class="line"></span><br><span class="line">data_loader_test_img = torch.utils.data.DataLoader(dataset=data_test_img,</span><br><span class="line">                                                  batch_size = <span class="number">16</span>,shuffle=<span 
class="literal">True</span>) <span class="comment">#载入测试数据集，并随机打乱</span></span><br><span class="line">classes = data_test_img.classes   <span class="comment">##class</span></span><br><span class="line"></span><br><span class="line">image, label = <span class="built_in">next</span>(<span class="built_in">iter</span>(data_loader_test_img))</span><br><span class="line">images = Variable(image).cuda()</span><br><span class="line">y_pred = model(images)</span><br><span class="line">_,pred = torch.<span class="built_in">max</span>(y_pred.data, <span class="number">1</span>)</span><br><span class="line"><span class="built_in">print</span>(pred)</span><br><span class="line"><span class="built_in">print</span>(label)</span><br><span class="line"></span><br><span class="line">img = torchvision.utils.make_grid(image)</span><br><span class="line">img = img.numpy().transpose(<span class="number">1</span>,<span class="number">2</span>,<span class="number">0</span>)</span><br><span class="line">mean = [<span class="number">0.5</span>, <span class="number">0.5</span>, <span class="number">0.5</span>]</span><br><span class="line">std = [<span class="number">0.5</span>, <span class="number">0.5</span>, <span class="number">0.5</span>]</span><br><span class="line">img = img * std + mean</span><br><span class="line"><span class="built_in">print</span>(<span class="string">&quot;Pred Label:&quot;</span>, [classes[i] <span class="keyword">for</span> i <span class="keyword">in</span> pred])</span><br><span class="line">plt.imshow(img)</span><br><span class="line">plt.show()</span><br></pre></td></tr></table></figure><h3 id="7-2-实例二，Deep-Dream：探索卷积神经网络眼中的世界"><a href="#7-2-实例二，Deep-Dream：探索卷积神经网络眼中的世界" class="headerlink" title="7.2 实例二，Deep Dream：探索卷积神经网络眼中的世界"></a>7.2 实例二，Deep Dream：探索卷积神经网络眼中的世界</h3><p>2015年，Google发布了一个有意思的东西，叫做Deep Dream</p><h4 id="7-2-1-原理介绍"><a href="#7-2-1-原理介绍" class="headerlink" title="7.2.1 原理介绍"></a>7.2.1 
原理介绍</h4><p><strong>1.反向神经网络</strong></p><p>我们知道经过训练之后，每一层网络逐步提取越来越高级的图像特征，直到最后一层将这些特征比较做出分类的结果。比如前面几层也许在寻找边缘和拐角的特征，中间几层分析整体的轮廓特征，这样不断的增加层数就可以发展出越来越多的复杂特征，最后几层将这些特征要素组合起来形成完整的解释，这样到最后网络就会对非常复杂的东西，比如小猫，树叶等图片有所反应</p><p><strong>2.Deep Dream</strong></p><p>如果我们将算法反复地应用到自身的输出上，不断迭代，并在每次迭代后应用一些缩放，就能不断地激活特征，得到无尽的新效果。</p><h4 id="7-2-2-代码实现"><a href="#7-2-2-代码实现" class="headerlink" title="7.2.2 代码实现"></a>7.2.2 代码实现</h4><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span
class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br><span class="line">87</span><br><span class="line">88</span><br><span class="line">89</span><br><span class="line">90</span><br><span class="line">91</span><br><span class="line">92</span><br><span class="line">93</span><br><span class="line">94</span><br><span class="line">95</span><br><span class="line">96</span><br><span class="line">97</span><br><span class="line">98</span><br><span class="line">99</span><br><span class="line">100</span><br><span class="line">101</span><br><span class="line">102</span><br><span class="line">103</span><br><span class="line">104</span><br><span class="line">105</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span 
class="keyword">import</span> torch.nn <span class="keyword">as</span> nn</span><br><span class="line"><span class="keyword">from</span> torch.autograd <span class="keyword">import</span> Variable</span><br><span class="line"><span class="keyword">from</span> torchvision <span class="keyword">import</span> models</span><br><span class="line"><span class="keyword">from</span> torchvision <span class="keyword">import</span> transforms, utils</span><br><span class="line"><span class="keyword">import</span> numpy <span class="keyword">as</span> np</span><br><span class="line"><span class="keyword">import</span> matplotlib.pyplot <span class="keyword">as</span> plt</span><br><span class="line">%matplotlib inline</span><br><span class="line"><span class="comment"># PIL.ImageFilter是Python中的图像滤波，主要对图像进行平滑、锐化、边界增强等滤波处理</span></span><br><span class="line"><span class="comment"># PIL.ImageChops模块包含一些算术图形操作，叫做channel operations（“chops”）。这些操作可用于诸多目的，比如图像特效，图像组合，算法绘图等等</span></span><br><span class="line"><span class="keyword">from</span> PIL <span class="keyword">import</span> Image, ImageFilter, ImageChops</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="comment"># 加载图像并显示</span></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">load_image</span>(<span class="params">path</span>):</span></span><br><span class="line">    image = Image.<span class="built_in">open</span>(path)</span><br><span class="line">    plt.imshow(image)</span><br><span class="line">    plt.title(<span class="string">&quot;Image loaded successfully&quot;</span>)</span><br><span class="line">    <span class="keyword">return</span> image</span><br><span class="line"></span><br><span class="line"><span class="comment"># 对数据集的标准化设置——减去均值再除以标准差</span></span><br><span class="line">normalise = transforms.Normalize(</span><br><span class="line">    mean=[<span class="number">0.485</span>, <span 
class="number">0.456</span>, <span class="number">0.406</span>],</span><br><span class="line">    std=[<span class="number">0.229</span>, <span class="number">0.224</span>, <span class="number">0.225</span>]</span><br><span class="line">    )</span><br><span class="line"></span><br><span class="line"><span class="comment"># 数据集的预处理，包括缩放、转换成Tensor、标准化</span></span><br><span class="line">preprocess = transforms.Compose([</span><br><span class="line">    transforms.Resize((<span class="number">224</span>,<span class="number">224</span>)),</span><br><span class="line">    transforms.ToTensor(),</span><br><span class="line">    normalise</span><br><span class="line">    ])</span><br><span class="line"></span><br><span class="line"><span class="comment"># 逆向处理过程，逆标准化，图像乘以标准差再加上均值</span></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">deprocess</span>(<span class="params">image</span>):</span></span><br><span class="line">    <span class="keyword">return</span> image * torch.Tensor([<span class="number">0.229</span>, <span class="number">0.224</span>, <span class="number">0.225</span>]).cuda()  + torch.Tensor([<span class="number">0.485</span>, <span class="number">0.456</span>, <span class="number">0.406</span>]).cuda()</span><br><span class="line"></span><br><span class="line"><span class="comment"># 下载vgg16的预训练模型，传到GPU上，输出网络结构</span></span><br><span class="line">vgg = models.vgg16(pretrained=<span class="literal">True</span>)</span><br><span class="line">vgg = vgg.cuda()</span><br><span class="line">modulelist = <span class="built_in">list</span>(vgg.features.modules())</span><br><span class="line"></span><br><span class="line"><span class="comment"># 这是deep dream的实际代码，特定层的梯度被设置为等于该层的响应，这导致了该层响应最大化。换句话说，我们正在增强一层检测到的特征，对输入图像（octaves）应用梯度上升算法。</span></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">dd_helper</span>(<span class="params">image, layer, 
iterations, lr</span>):</span>        </span><br><span class="line">    <span class="comment"># 一开始的输入是图像经过预处理、在正数第一个维度上增加一个维度以匹配神经网络的输入、传到GPU上</span></span><br><span class="line">    <span class="built_in">input</span> = Variable(preprocess(image).unsqueeze(<span class="number">0</span>).cuda(), requires_grad=<span class="literal">True</span>)</span><br><span class="line">    <span class="comment"># vgg梯度清零</span></span><br><span class="line">    vgg.zero_grad()</span><br><span class="line">    <span class="comment"># 开始迭代</span></span><br><span class="line">    <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(iterations):</span><br><span class="line">        <span class="comment"># 一层一层传递输入</span></span><br><span class="line">        out = <span class="built_in">input</span></span><br><span class="line">        <span class="keyword">for</span> j <span class="keyword">in</span> <span class="built_in">range</span>(layer):</span><br><span class="line">            out = modulelist[j+<span class="number">1</span>](out)</span><br><span class="line">        <span class="comment"># 损失是输出的范数</span></span><br><span class="line">        loss = out.norm()</span><br><span class="line">        <span class="comment"># 损失反向传播</span></span><br><span class="line">        loss.backward()</span><br><span class="line">        <span class="comment"># 输入的数据是上次迭代时的输入数据+学习率×输入的梯度</span></span><br><span class="line">        <span class="built_in">input</span>.data = <span class="built_in">input</span>.data + lr * <span class="built_in">input</span>.grad.data</span><br><span class="line">    <span class="comment"># 将从网络结构中取出的输入数据的第一个维度去掉</span></span><br><span class="line">    <span class="built_in">input</span> = <span class="built_in">input</span>.data.squeeze()</span><br><span class="line">    <span class="comment"># 矩阵转置</span></span><br><span class="line">    <span class="built_in">input</span>.transpose_(<span 
class="number">0</span>,<span class="number">1</span>)</span><br><span class="line">    <span class="built_in">input</span>.transpose_(<span class="number">1</span>,<span class="number">2</span>)</span><br><span class="line">    <span class="comment"># 将输入逆标准化后强制截断在0到1的范围内</span></span><br><span class="line">    <span class="built_in">input</span> = np.clip(deprocess(<span class="built_in">input</span>), <span class="number">0</span>, <span class="number">1</span>)</span><br><span class="line">    <span class="comment"># 得到像素值为0到255的图像</span></span><br><span class="line">    im = Image.fromarray(np.uint8(<span class="built_in">input</span>*<span class="number">255</span>))</span><br><span class="line">    <span class="keyword">return</span> im</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="comment"># 这是一个递归函数，用于创建octaves，并且将由一次递归调用生成的图像与由上一级递归调用生成的图像相融合</span></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">deep_dream_vgg</span>(<span class="params">image, layer, iterations, lr, octave_scale, num_octaves</span>):</span></span><br><span class="line">    <span class="comment"># 若octave序号大于0，即还未到达最底层的octave时，一层一层递归</span></span><br><span class="line">    <span class="keyword">if</span> num_octaves&gt;<span class="number">0</span>:</span><br><span class="line">        <span class="comment"># 对图像进行高斯滤波（高斯模糊）</span></span><br><span class="line">        image1 = image.<span class="built_in">filter</span>(ImageFilter.GaussianBlur(<span class="number">2</span>))</span><br><span class="line">        <span class="comment"># 判断是否缩放</span></span><br><span class="line">        <span class="keyword">if</span>(image1.size[<span class="number">0</span>]/octave_scale &lt; <span class="number">1</span> <span class="keyword">or</span> image1.size[<span class="number">1</span>]/octave_scale&lt;<span class="number">1</span>):</span><br><span class="line">            size = 
image1.size</span><br><span class="line">        <span class="keyword">else</span>:</span><br><span class="line">            size = (<span class="built_in">int</span>(image1.size[<span class="number">0</span>]/octave_scale), <span class="built_in">int</span>(image1.size[<span class="number">1</span>]/octave_scale))</span><br><span class="line">        <span class="comment"># 图像缩放    </span></span><br><span class="line">        image1 = image1.resize(size,Image.ANTIALIAS)</span><br><span class="line">        <span class="comment"># 递归调用，直至num_octave==0</span></span><br><span class="line">        image1 = deep_dream_vgg(image1, layer, iterations, lr, octave_scale, num_octaves-<span class="number">1</span>)</span><br><span class="line">        size = (image.size[<span class="number">0</span>], image.size[<span class="number">1</span>])</span><br><span class="line">        <span class="comment"># 将图像缩放到最初输入图像的大小</span></span><br><span class="line">        image1 = image1.resize(size,Image.ANTIALIAS)</span><br><span class="line">        <span class="comment"># 将最初输入的图像与合成的相同尺寸大小的图像融合</span></span><br><span class="line">        image = ImageChops.blend(image, image1, <span class="number">0.6</span>)</span><br><span class="line"><span class="comment">#     print(&quot;-------------- Recursive level: &quot;, num_octaves, &#x27;--------------&#x27;)</span></span><br><span class="line">    <span class="comment"># 按照dd_helper中的流程生成图像</span></span><br><span class="line">    img_result = dd_helper(image, layer, iterations, lr)</span><br><span class="line">    <span class="comment"># 图像缩放并显示</span></span><br><span class="line">    img_result = img_result.resize(image.size)</span><br><span class="line">    plt.imshow(img_result)</span><br><span class="line">    <span class="keyword">return</span> img_result</span><br><span class="line">    </span><br><span class="line"><span class="comment"># 加载图像(原始图像)</span></span><br><span class="line">sky = load_image(<span 
class="string">&#x27;1.jpg&#x27;</span>)</span><br><span class="line"></span><br><span class="line"><span class="comment"># 对于vgg16最后一个卷积层conv5_3,迭代5次，学习率为0.2,octave缩放比例为2,octave从第20层开始</span></span><br><span class="line">sky_28 = deep_dream_vgg(sky, <span class="number">28</span>, <span class="number">5</span>, <span class="number">0.2</span>, <span class="number">2</span>, <span class="number">20</span>)</span><br></pre></td></tr></table></figure></article><div class="post-reward"><div class="reward-button"><i class="fas fa-hamburger"></i> 打赏作者</div><div class="reward-main"><ul class="reward-all"><ul class="reward-group"><li class="reward-item"><a href="/img/wechat.jpg" target="_blank"><img class="post-qr-code-img" src="" data-lazy-src="/img/wechat.jpg" alt="微信"></a><div class="post-qr-code-desc">微信</div></li><li class="reward-item"><a href="/img/alipay.jpg" target="_blank"><img class="post-qr-code-img" src="" data-lazy-src="/img/alipay.jpg" alt="支付宝"></a><div class="post-qr-code-desc">支付宝</div></li></ul><a class="reward-main-btn" href="/donate"><div class="reward-text">赞赏者名单</div><div class="reward-dec">因为你们的支持让我意识到写文章的价值🙏</div></a></ul></div></div><div class="tag_share"><div class="post-meta__tag-list"><a class="post-meta__tags" href="/tags/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/">深度学习</a><a class="post-meta__tags" href="/tags/python/">python</a><a class="post-meta__tags" href="/tags/pytorch/">pytorch</a></div><div class="post_share"><div class="social-share" data-image="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN2@latest/post/pytorch.jpg" data-sites="facebook,twitter,wechat,weibo,qq"></div><link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/social-share.js/dist/css/share.min.css" media="print" onload='this.media="all"'><script src="https://cdn.jsdelivr.net/npm/social-share.js/dist/js/social-share.min.js" defer="defer"></script></div></div><div class="post-copyright"><div class="post-copyright__author"><span class="post-copyright-info">深度学习 | 
《深度学习入门之PyTorch》阅读笔记</span></div><div class="post-copyright__type"><span class="post-copyright-info"><a href="https://blog.justlovesmile.top/posts/bfa4054.html">https://blog.justlovesmile.top/posts/bfa4054.html</a></span></div><div class="post-copyright__notice"><span class="post-copyright-info">本博客所有文章除特别声明外，均采用 <a href="https://creativecommons.org/licenses/by-nc-sa/4.0/" target="_blank" rel="external nofollow noopener noreferrer">CC BY-NC-SA 4.0</a> 许可协议。转载请注明来自 <a href="https://blog.justlovesmile.top" target="_blank">Justlovesmile's BLOG</a>！</span></div></div><nav class="pagination-post" id="pagination"><div class="prev-post pull-left"><a href="/posts/ef2381bd.html"><img class="prev-cover" src="" data-lazy-src="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN2/post/v2-8dd14775ab8c91a09507f52e44f347f3_720w.jpg" onerror='onerror=null,src="/img/404.jpg"' alt="cover of previous post"><div class="pagination-info"><div class="label">上一篇</div><div class="prev_info">深度学习 | 如何理解卷积</div></div></a></div><div class="next-post pull-right"><a href="/posts/43678.html"><img class="next-cover" src="" data-lazy-src="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN2/post/2199659750740396073260.jpg" onerror='onerror=null,src="/img/404.jpg"' alt="cover of next post"><div class="pagination-info"><div class="label">下一篇</div><div class="next_info">深度学习 | “花书”，Deep Learning笔记</div></div></a></div></nav><div class="relatedPosts"><div class="headline"><i class="fas fa-thumbs-up fa-fw"></i><span>相关推荐</span></div><div class="relatedPosts-list"><div><a href="/posts/865c56ba.html" title="目标检测 | 常用数据集标注格式及生成脚本"><img class="cover" src="" data-lazy-src="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN2/post/202109111517311.jpg" alt="cover"><div class="content is-center"><div class="date"><i class="far fa-calendar-alt fa-fw"></i> 2021-09-11</div><div class="title">目标检测 | 常用数据集标注格式及生成脚本</div></div></a></div><div><a href="/posts/bb608df3.html" title="目标检测 | RetinaNet，经典单阶段Anchor-Based目标检测模型"><img 
class="cover" src="" data-lazy-src="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN2/post/20220314113659.png" alt="cover"><div class="content is-center"><div class="date"><i class="far fa-calendar-alt fa-fw"></i> 2022-03-14</div><div class="title">目标检测 | RetinaNet，经典单阶段Anchor-Based目标检测模型</div></div></a></div><div><a href="/posts/fc798de3.html" title="目标检测 | Faster R-CNN，经典两阶段检测模型"><img class="cover" src="" data-lazy-src="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN2/post/20220312220823.png" alt="cover"><div class="content is-center"><div class="date"><i class="far fa-calendar-alt fa-fw"></i> 2022-03-12</div><div class="title">目标检测 | Faster R-CNN，经典两阶段检测模型</div></div></a></div><div><a href="/posts/d150f284.html" title="深度学习 | 小样本学习基础概念"><img class="cover" src="" data-lazy-src="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN2/post/202201271037441.png" alt="cover"><div class="content is-center"><div class="date"><i class="far fa-calendar-alt fa-fw"></i> 2022-01-27</div><div class="title">深度学习 | 小样本学习基础概念</div></div></a></div><div><a href="/posts/6a054795.html" title="深度学习 | GAN，什么是生成对抗网络"><img class="cover" src="" data-lazy-src="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN2/post/20210226103604.jpg" alt="cover"><div class="content is-center"><div class="date"><i class="far fa-calendar-alt fa-fw"></i> 2021-03-03</div><div class="title">深度学习 | GAN，什么是生成对抗网络</div></div></a></div><div><a href="/posts/ebe3a70b.html" title="深度学习 | Wasserstein距离"><img class="cover" src="" data-lazy-src="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN2/post/20210131105502.jpeg" alt="cover"><div class="content is-center"><div class="date"><i class="far fa-calendar-alt fa-fw"></i> 2021-01-31</div><div class="title">深度学习 | Wasserstein距离</div></div></a></div></div></div><hr><div id="post-comment"><div class="comment-head"><div class="comment-headline"><i class="fas fa-comments fa-fw"></i> <span>评论</span></div></div><div class="comment-wrap"><div><div id="twikoo-wrap"></div></div></div></div></div><div 
class="aside-content" id="aside-content"><div class="card-widget card-info"><div class="is-center"><div class="avatar-img"><img src="" data-lazy-src="/img/avatar.jpg" onerror='this.onerror=null,this.src="/img/friend_404.gif"' alt="avatar"></div><div class="author-info__name">Justlovesmile</div><div class="author-info__description">一个计算机专业学生的个人博客，记录着学习笔记和生活中的思考，期待着和所有人相遇</div></div><div class="card-info-data is-center"><div class="card-info-data-item"><a href="/archives/"><div class="headline">文章</div><div class="length-num">75</div></a></div><div class="card-info-data-item"><a href="/tags/"><div class="headline">标签</div><div class="length-num">69</div></a></div><div class="card-info-data-item"><a href="/categories/"><div class="headline">分类</div><div class="length-num">6</div></a></div></div></div><div class="card-widget card-announcement"><div class="item-headline"><i class="fas fa-bullhorn card-announcement-animation"></i><span>公告</span></div><div class="announcement_content"><p>不定时更新博客，欢迎交换<a href="/friends/"><strong>友链</strong></a>...</p><div class="twopeople"><div class="container" style="height:200px"><canvas class="illo" width="800" height="800" style="max-width:200px;max-height:200px;touch-action:none;width:640px;height:640px"></canvas></div><script src="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN/js/twopeople1.js"></script><script src="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN/js/zdog.dist.js"></script><script id="rendered-js" src="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN/js/twopeople.js"></script><style>.twopeople{margin:0;align-items:center;justify-content:center;text-align:center}canvas{display:block;margin:0 auto;cursor:move}</style></div><div style="text-align:center"><a href="https://www.foreverblog.cn/" target="_blank" rel="external nofollow noopener noreferrer"><img src="" data-lazy-src="https://img.foreverblog.cn/logo_en_default.png" alt="foreverblog" style="width:auto;height:16px"></a></div></div></div><div class="sticky_layout"><div 
class="card-widget" id="card-toc"><div class="item-headline"><i class="fas fa-stream"></i><span>目录</span><span class="toc-percentage"></span></div><div class="toc-content"><ol class="toc"><li class="toc-item toc-level-1"><a class="toc-link" href="#%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E5%85%A5%E9%97%A8%E4%B9%8BPyTorch"><span class="toc-text">深度学习入门之PyTorch</span></a><ol class="toc-child"><li class="toc-item toc-level-2"><a class="toc-link" href="#%E7%AC%AC%E4%B8%80%E7%AB%A0-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E4%BB%8B%E7%BB%8D"><span class="toc-text">第一章 深度学习介绍</span></a><ol class="toc-child"><li class="toc-item toc-level-3"><a class="toc-link" href="#1-1-%E4%BA%BA%E5%B7%A5%E6%99%BA%E8%83%BD"><span class="toc-text">1.1 人工智能</span></a></li><li class="toc-item toc-level-3"><a class="toc-link" href="#1-2-%E6%95%B0%E6%8D%AE%E6%8C%96%E6%8E%98%EF%BC%8C%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E5%92%8C%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0"><span class="toc-text">1.2 数据挖掘，机器学习和深度学习</span></a><ol class="toc-child"><li class="toc-item toc-level-4"><a class="toc-link" href="#1-2-1-%E6%95%B0%E6%8D%AE%E6%8C%96%E6%8E%98"><span class="toc-text">1.2.1 数据挖掘</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#1-2-2-%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0"><span class="toc-text">1.2.2 机器学习</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#1-2-3-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0"><span class="toc-text">1.2.3 深度学习</span></a></li></ol></li></ol></li><li class="toc-item toc-level-2"><a class="toc-link" href="#%E7%AC%AC%E4%BA%8C%E7%AB%A0-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E6%A1%86%E6%9E%B6"><span class="toc-text">第二章 深度学习框架</span></a><ol class="toc-child"><li class="toc-item toc-level-3"><a class="toc-link" href="#2-1-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E6%A1%86%E6%9E%B6%E4%BB%8B%E7%BB%8D"><span class="toc-text">2.1 深度学习框架介绍</span></a></li><li class="toc-item toc-level-3"><a class="toc-link" href="#2-2-PyTorch%E4%BB%8B%E7%BB%8D"><span 
class="toc-text">2.2 PyTorch介绍</span></a><ol class="toc-child"><li class="toc-item toc-level-4"><a class="toc-link" href="#2-2-1-%E4%BB%80%E4%B9%88%E6%98%AFPyTorch"><span class="toc-text">2.2.1 什么是PyTorch</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#2-2-2-%E4%B8%BA%E4%BB%80%E4%B9%88%E4%BD%BF%E7%94%A8PyTorch"><span class="toc-text">2.2.2 为什么使用PyTorch</span></a></li></ol></li><li class="toc-item toc-level-3"><a class="toc-link" href="#2-3-%E9%85%8D%E7%BD%AEPyTorch%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E7%8E%AF%E5%A2%83"><span class="toc-text">2.3 配置PyTorch深度学习环境</span></a><ol class="toc-child"><li class="toc-item toc-level-4"><a class="toc-link" href="#2-3-1-%E6%93%8D%E4%BD%9C%E7%B3%BB%E7%BB%9F"><span class="toc-text">2.3.1 操作系统</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#2-3-2-Python%E5%BC%80%E5%8F%91%E7%8E%AF%E5%A2%83%E7%9A%84%E5%AE%89%E8%A3%85"><span class="toc-text">2.3.2 Python开发环境的安装</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#2-3-3-PyTorch%E5%AE%89%E8%A3%85"><span class="toc-text">2.3.3 PyTorch安装</span></a></li></ol></li></ol></li><li class="toc-item toc-level-2"><a class="toc-link" href="#%E7%AC%AC%E4%B8%89%E7%AB%A0-%E5%A4%9A%E5%B1%82%E5%85%A8%E8%BF%9E%E6%8E%A5%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C"><span class="toc-text">第三章 多层全连接神经网络</span></a><ol class="toc-child"><li class="toc-item toc-level-3"><a class="toc-link" href="#3-1-PyTorch%E5%9F%BA%E7%A1%80"><span class="toc-text">3.1 PyTorch基础</span></a><ol class="toc-child"><li class="toc-item toc-level-4"><a class="toc-link" href="#3-1-1-Tensor%E5%BC%A0%E9%87%8F"><span class="toc-text">3.1.1 Tensor张量</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#3-1-2-Variable%EF%BC%88%E5%8F%98%E9%87%8F%EF%BC%89"><span class="toc-text">3.1.2 Variable（变量）</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#3-1-3-Dataset-%E6%95%B0%E6%8D%AE%E9%9B%86"><span class="toc-text">3.1.3 
Dataset(数据集)</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#3-1-4-nn-Module-%E6%A8%A1%E7%BB%84"><span class="toc-text">3.1.4 nn.Module(模组)</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#3-1-5-torch-optim-%E4%BC%98%E5%8C%96"><span class="toc-text">3.1.5 torch.optim(优化)</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#3-1-6-%E6%A8%A1%E5%9E%8B%E7%9A%84%E4%BF%9D%E5%AD%98%E5%92%8C%E5%8A%A0%E8%BD%BD"><span class="toc-text">3.1.6 模型的保存和加载</span></a></li></ol></li><li class="toc-item toc-level-3"><a class="toc-link" href="#3-2-%E7%BA%BF%E6%80%A7%E6%A8%A1%E5%9E%8B"><span class="toc-text">3.2 线性模型</span></a><ol class="toc-child"><li class="toc-item toc-level-4"><a class="toc-link" href="#3-2-1-%E4%BB%8B%E7%BB%8D"><span class="toc-text">3.2.1 介绍</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#3-2-2-%E4%B8%80%E7%BB%B4%E7%BA%BF%E6%80%A7%E5%9B%9E%E5%BD%92"><span class="toc-text">3.2.2 一维线性回归</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#3-2-3-%E5%A4%9A%E7%BB%B4%E7%BA%BF%E6%80%A7%E5%9B%9E%E5%BD%92"><span class="toc-text">3.2.3 多维线性回归</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#3-2-4-%E4%B8%80%E7%BB%B4%E7%BA%BF%E6%80%A7%E5%9B%9E%E5%BD%92%E7%9A%84%E4%BB%A3%E7%A0%81%E5%AE%9E%E7%8E%B0"><span class="toc-text">3.2.4 一维线性回归的代码实现</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#3-2-5-%E5%A4%9A%E9%A1%B9%E5%BC%8F%E5%9B%9E%E5%BD%92"><span class="toc-text">3.2.5 多项式回归</span></a></li></ol></li><li class="toc-item toc-level-3"><a class="toc-link" href="#3-3-%E5%88%86%E7%B1%BB%E9%97%AE%E9%A2%98"><span class="toc-text">3.3 分类问题</span></a><ol class="toc-child"><li class="toc-item toc-level-4"><a class="toc-link" href="#3-3-1-%E9%97%AE%E9%A2%98%E4%BB%8B%E7%BB%8D"><span class="toc-text">3.3.1 问题介绍</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" 
href="#3-3-2-Logistic%E8%B5%B7%E6%BA%90"><span class="toc-text">3.3.2 Logistic起源</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#3-3-3-Logistic%E5%88%86%E5%B8%83"><span class="toc-text">3.3.3 Logistic分布</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#3-3-4-%E4%BA%8C%E5%88%86%E7%B1%BB%E7%9A%84Logistic%E5%9B%9E%E5%BD%92"><span class="toc-text">3.3.4 二分类的Logistic回归</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#3-3-5-%E6%A8%A1%E5%9E%8B%E7%9A%84%E5%8F%82%E6%95%B0%E4%BC%B0%E8%AE%A1"><span class="toc-text">3.3.5 模型的参数估计</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#3-3-6-Logistic%E5%9B%9E%E5%BD%92%E7%9A%84%E4%BB%A3%E7%A0%81%E5%AE%9E%E7%8E%B0"><span class="toc-text">3.3.6 Logistic回归的代码实现</span></a></li></ol></li><li class="toc-item toc-level-3"><a class="toc-link" href="#3-4-%E7%AE%80%E5%8D%95%E5%A4%9A%E5%B1%82%E5%85%A8%E8%BF%9E%E6%8E%A5%E5%89%8D%E5%90%91%E7%BD%91%E7%BB%9C"><span class="toc-text">3.4 简单多层全连接前向网络</span></a><ol class="toc-child"><li class="toc-item toc-level-4"><a class="toc-link" href="#3-4-1-%E6%A8%A1%E6%8B%9F%E7%A5%9E%E7%BB%8F%E5%85%83"><span class="toc-text">3.4.1 模拟神经元</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#3-4-2-%E5%8D%95%E5%B1%82%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C%E7%9A%84%E5%88%86%E7%B1%BB%E5%99%A8"><span class="toc-text">3.4.2 单层神经网络的分类器</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#3-4-3-%E6%BF%80%E6%B4%BB%E5%87%BD%E6%95%B0"><span class="toc-text">3.4.3 激活函数</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#3-4-4-%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C%E7%9A%84%E7%BB%93%E6%9E%84"><span class="toc-text">3.4.4 神经网络的结构</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#3-4-5-%E6%A8%A1%E5%9E%8B%E7%9A%84%E8%A1%A8%E7%A4%BA%E8%83%BD%E5%8A%9B%E4%B8%8E%E5%AE%B9%E9%87%8F"><span class="toc-text">3.4.5 
模型的表示能力与容量</span></a></li></ol></li><li class="toc-item toc-level-3"><a class="toc-link" href="#3-5-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E7%9A%84%E5%9F%BA%E7%9F%B3%EF%BC%9A%E5%8F%8D%E5%90%91%E4%BC%A0%E6%92%AD%E7%AE%97%E6%B3%95"><span class="toc-text">3.5 深度学习的基石：反向传播算法</span></a><ol class="toc-child"><li class="toc-item toc-level-4"><a class="toc-link" href="#3-5-1-%E9%93%BE%E5%BC%8F%E6%B3%95%E5%88%99"><span class="toc-text">3.5.1 链式法则</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#3-5-2-%E5%8F%8D%E5%90%91%E4%BC%A0%E6%92%AD%E7%AE%97%E6%B3%95"><span class="toc-text">3.5.2 反向传播算法</span></a></li></ol></li><li class="toc-item toc-level-3"><a class="toc-link" href="#3-6-%E5%90%84%E7%A7%8D%E4%BC%98%E5%8C%96%E7%AE%97%E6%B3%95%E7%9A%84%E5%8F%98%E5%BC%8F"><span class="toc-text">3.6 各种优化算法的变式</span></a><ol class="toc-child"><li class="toc-item toc-level-4"><a class="toc-link" href="#3-6-1-%E6%A2%AF%E5%BA%A6%E4%B8%8B%E9%99%8D%E6%B3%95"><span class="toc-text">3.6.1 梯度下降法</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#3-6-2-%E6%A2%AF%E5%BA%A6%E4%B8%8B%E9%99%8D%E6%B3%95%E7%9A%84%E5%8F%98%E5%BC%8F"><span class="toc-text">3.6.2 梯度下降法的变式</span></a></li></ol></li><li class="toc-item toc-level-3"><a class="toc-link" href="#3-7-%E5%A4%84%E7%90%86%E6%95%B0%E6%8D%AE%E5%92%8C%E8%AE%AD%E7%BB%83%E6%A8%A1%E5%9E%8B%E7%9A%84%E6%8A%80%E5%B7%A7"><span class="toc-text">3.7 处理数据和训练模型的技巧</span></a><ol class="toc-child"><li class="toc-item toc-level-4"><a class="toc-link" href="#3-7-1-%E6%95%B0%E6%8D%AE%E9%A2%84%E5%A4%84%E7%90%86"><span class="toc-text">3.7.1 数据预处理</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#3-7-2-%E6%9D%83%E9%87%8D%E5%88%9D%E5%A7%8B%E5%8C%96"><span class="toc-text">3.7.2 权重初始化</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#3-7-3-%E9%98%B2%E6%AD%A2%E8%BF%87%E6%8B%9F%E5%90%88"><span class="toc-text">3.7.3 防止过拟合</span></a></li></ol></li><li class="toc-item 
toc-level-3"><a class="toc-link" href="#3-8-%E5%A4%9A%E5%B1%82%E5%85%A8%E8%BF%9E%E6%8E%A5%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C%E5%AE%9E%E7%8E%B0MNIST%E6%89%8B%E5%86%99%E6%95%B0%E5%AD%97%E5%88%86%E7%B1%BB"><span class="toc-text">3.8 多层全连接神经网络实现MNIST手写数字分类</span></a></li></ol></li><li class="toc-item toc-level-2"><a class="toc-link" href="#%E7%AC%AC%E5%9B%9B%E7%AB%A0-%E5%8D%B7%E7%A7%AF%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C"><span class="toc-text">第四章 卷积神经网络</span></a><ol class="toc-child"><li class="toc-item toc-level-3"><a class="toc-link" href="#4-1-%E4%B8%BB%E8%A6%81%E4%BB%BB%E5%8A%A1%E5%8F%8A%E8%B5%B7%E6%BA%90"><span class="toc-text">4.1 主要任务及起源</span></a></li><li class="toc-item toc-level-3"><a class="toc-link" href="#4-2-%E5%8D%B7%E7%A7%AF%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C%E7%9A%84%E5%8E%9F%E7%90%86%E5%92%8C%E7%BB%93%E6%9E%84"><span class="toc-text">4.2 卷积神经网络的原理和结构</span></a><ol class="toc-child"><li class="toc-item toc-level-4"><a class="toc-link" href="#4-2-1-%E5%8D%B7%E7%A7%AF%E5%B1%82"><span class="toc-text">4.2.1 卷积层</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#4-2-2-%E6%B1%A0%E5%8C%96%E5%B1%82"><span class="toc-text">4.2.2 池化层</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#4-2-3-%E5%85%A8%E8%BF%9E%E6%8E%A5%E5%B1%82"><span class="toc-text">4.2.3 全连接层</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#4-2-4-%E5%8D%B7%E7%A7%AF%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C%E7%9A%84%E5%9F%BA%E6%9C%AC%E5%BD%A2%E5%BC%8F"><span class="toc-text">4.2.4 卷积神经网络的基本形式</span></a></li></ol></li><li class="toc-item toc-level-3"><a class="toc-link" href="#4-3-Pytorch%E5%8D%B7%E7%A7%AF%E6%A8%A1%E5%9D%97"><span class="toc-text">4.3 Pytorch卷积模块</span></a><ol class="toc-child"><li class="toc-item toc-level-4"><a class="toc-link" href="#4-3-1-%E5%8D%B7%E7%A7%AF%E5%B1%82"><span class="toc-text">4.3.1 卷积层</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" 
href="#4-3-2-%E6%B1%A0%E5%8C%96%E5%B1%82"><span class="toc-text">4.3.2 池化层</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#4-3-3-%E6%8F%90%E5%8F%96%E5%B1%82%E7%BB%93%E6%9E%84"><span class="toc-text">4.3.3 提取层结构</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#4-3-4-%E6%8F%90%E5%8F%96%E5%8F%82%E6%95%B0%E5%8F%8A%E8%87%AA%E5%AE%9A%E4%B9%89%E5%88%9D%E5%A7%8B%E5%8C%96"><span class="toc-text">4.3.4 提取参数及自定义初始化</span></a></li></ol></li><li class="toc-item toc-level-3"><a class="toc-link" href="#4-4-%E5%8D%B7%E7%A7%AF%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C%E6%A1%88%E4%BE%8B%E5%88%86%E6%9E%90"><span class="toc-text">4.4 卷积神经网络案例分析</span></a><ol class="toc-child"><li class="toc-item toc-level-4"><a class="toc-link" href="#4-4-1-LeNet"><span class="toc-text">4.4.1 LeNet</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#4-4-2-AlexNet"><span class="toc-text">4.4.2 AlexNet</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#4-4-3-VGGNet"><span class="toc-text">4.4.3 VGGNet</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#4-4-4-GoogleNet"><span class="toc-text">4.4.4 GoogleNet</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#4-4-5-ResNet"><span class="toc-text">4.4.5 ResNet</span></a></li></ol></li><li class="toc-item toc-level-3"><a class="toc-link" href="#4-5-%E5%86%8D%E5%AE%9E%E7%8E%B0MNIST%E6%89%8B%E5%86%99%E6%95%B0%E5%AD%97%E5%88%86%E7%B1%BB"><span class="toc-text">4.5 再实现MNIST手写数字分类</span></a></li><li class="toc-item toc-level-3"><a class="toc-link" href="#4-6-%E5%9B%BE%E5%83%8F%E5%A2%9E%E5%BC%BA%E7%9A%84%E6%96%B9%E6%B3%95"><span class="toc-text">4.6 图像增强的方法</span></a></li><li class="toc-item toc-level-3"><a class="toc-link" href="#4-7-%E5%AE%9E%E7%8E%B0cifar10%E5%88%86%E7%B1%BB"><span class="toc-text">4.7 实现cifar10分类</span></a></li></ol></li><li class="toc-item toc-level-2"><a class="toc-link" 
href="#%E7%AC%AC%E4%BA%94%E7%AB%A0-%E5%BE%AA%E7%8E%AF%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C"><span class="toc-text">第五章 循环神经网络</span></a><ol class="toc-child"><li class="toc-item toc-level-3"><a class="toc-link" href="#5-1-%E5%BE%AA%E7%8E%AF%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C"><span class="toc-text">5.1 循环神经网络</span></a><ol class="toc-child"><li class="toc-item toc-level-4"><a class="toc-link" href="#5-1-1-%E9%97%AE%E9%A2%98%E4%BB%8B%E7%BB%8D"><span class="toc-text">5.1.1 问题介绍</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#5-1-2-%E5%BE%AA%E7%8E%AF%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C%E7%9A%84%E5%9F%BA%E6%9C%AC%E7%BB%93%E6%9E%84"><span class="toc-text">5.1.2 循环神经网络的基本结构</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#5-1-3-%E5%AD%98%E5%9C%A8%E7%9A%84%E9%97%AE%E9%A2%98"><span class="toc-text">5.1.3 存在的问题</span></a></li></ol></li><li class="toc-item toc-level-3"><a class="toc-link" href="#5-2-%E5%BE%AA%E7%8E%AF%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C%E7%9A%84%E5%8F%98%E5%BC%8F%EF%BC%9ALSTM%E5%92%8CGRU"><span class="toc-text">5.2 循环神经网络的变式：LSTM和GRU</span></a><ol class="toc-child"><li class="toc-item toc-level-4"><a class="toc-link" href="#5-2-1-LSTM"><span class="toc-text">5.2.1 LSTM</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#5-2-2-GRU"><span class="toc-text">5.2.2 GRU</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#5-2-3-%E6%94%B6%E6%95%9B%E6%80%A7%E9%97%AE%E9%A2%98"><span class="toc-text">5.2.3 收敛性问题</span></a></li></ol></li><li class="toc-item toc-level-3"><a class="toc-link" href="#5-3-%E5%BE%AA%E7%8E%AF%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C%E7%9A%84PyTorch%E5%AE%9E%E7%8E%B0"><span class="toc-text">5.3 循环神经网络的PyTorch实现</span></a><ol class="toc-child"><li class="toc-item toc-level-4"><a class="toc-link" href="#5-3-1-PyTorch%E7%9A%84%E5%BE%AA%E7%8E%AF%E7%BD%91%E7%BB%9C%E6%A8%A1%E5%9D%97"><span class="toc-text">5.3.1 PyTorch的循环网络模块</span></a></li><li 
class="toc-item toc-level-4"><a class="toc-link" href="#5-3-2-%E5%AE%9E%E4%BE%8B%E4%BB%8B%E7%BB%8D"><span class="toc-text">5.3.2 实例介绍</span></a></li></ol></li><li class="toc-item toc-level-3"><a class="toc-link" href="#5-4-%E8%87%AA%E7%84%B6%E8%AF%AD%E8%A8%80%E5%A4%84%E7%90%86%E7%9A%84%E5%BA%94%E7%94%A8"><span class="toc-text">5.4 自然语言处理的应用</span></a><ol class="toc-child"><li class="toc-item toc-level-4"><a class="toc-link" href="#5-4-1-%E8%AF%8D%E5%B5%8C%E5%85%A5"><span class="toc-text">5.4.1 词嵌入</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#5-4-2-%E8%AF%8D%E5%B5%8C%E5%85%A5%E7%9A%84PyTorch%E5%AE%9E%E7%8E%B0"><span class="toc-text">5.4.2 词嵌入的PyTorch实现</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#5-4-3-N-Gram%E6%A8%A1%E5%9E%8B"><span class="toc-text">5.4.3 N Gram模型</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#5-4-4-%E5%8D%95%E8%AF%8D%E9%A2%84%E6%B5%8B%E7%9A%84PyTorch%E5%AE%9E%E7%8E%B0"><span class="toc-text">5.4.4 单词预测的PyTorch实现</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#5-4-5-%E8%AF%8D%E6%80%A7%E5%88%A4%E6%96%AD"><span class="toc-text">5.4.5 词性判断</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#5-4-6-%E8%AF%8D%E6%80%A7%E5%88%A4%E6%96%AD%E7%9A%84PyTorch%E5%AE%9E%E7%8E%B0"><span class="toc-text">5.4.6 词性判断的PyTorch实现</span></a></li></ol></li><li class="toc-item toc-level-3"><a class="toc-link" href="#5-5-%E5%BE%AA%E7%8E%AF%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C%E7%9A%84%E6%9B%B4%E5%A4%9A%E5%BA%94%E7%94%A8"><span class="toc-text">5.5 循环神经网络的更多应用</span></a><ol class="toc-child"><li class="toc-item toc-level-4"><a class="toc-link" href="#5-5-1-Many-to-one"><span class="toc-text">5.5.1 Many to one</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#5-5-2-Many-to-Many-shorter"><span class="toc-text">5.5.2 Many to Many (shorter)</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" 
href="#5-5-3-Seq2seq"><span class="toc-text">5.5.3 Seq2seq</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#5-5-4-CNN-RNN"><span class="toc-text">5.5.4 CNN+RNN</span></a></li></ol></li></ol></li><li class="toc-item toc-level-2"><a class="toc-link" href="#%E7%AC%AC6%E7%AB%A0-%E7%94%9F%E6%88%90%E5%AF%B9%E6%8A%97%E7%BD%91%E7%BB%9C"><span class="toc-text">第6章 生成对抗网络</span></a><ol class="toc-child"><li class="toc-item toc-level-3"><a class="toc-link" href="#6-1-%E7%94%9F%E6%88%90%E6%A8%A1%E5%9E%8B"><span class="toc-text">6.1 生成模型</span></a><ol class="toc-child"><li class="toc-item toc-level-4"><a class="toc-link" href="#6-1-1-%E8%87%AA%E5%8A%A8%E7%BC%96%E7%A0%81%E5%99%A8"><span class="toc-text">6.1.1 自动编码器</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#6-1-2-%E5%8F%98%E5%88%86%E8%87%AA%E5%8A%A8%E7%BC%96%E7%A0%81%E5%99%A8"><span class="toc-text">6.1.2 变分自动编码器</span></a></li></ol></li><li class="toc-item toc-level-3"><a class="toc-link" href="#6-2-%E7%94%9F%E6%88%90%E5%AF%B9%E6%8A%97%E7%BD%91%E7%BB%9C"><span class="toc-text">6.2 生成对抗网络</span></a><ol class="toc-child"><li class="toc-item toc-level-4"><a class="toc-link" href="#6-2-1-%E4%BB%80%E4%B9%88%E6%98%AF%E7%94%9F%E6%88%90%E5%AF%B9%E6%8A%97%E7%BD%91%E7%BB%9C"><span class="toc-text">6.2.1 什么是生成对抗网络</span></a></li></ol></li><li class="toc-item toc-level-3"><a class="toc-link" href="#6-3-Improving-GAN"><span class="toc-text">6.3 Improving GAN</span></a><ol class="toc-child"><li class="toc-item toc-level-4"><a class="toc-link" href="#6-3-1-Wasserstein-GAN"><span class="toc-text">6.3.1 Wasserstein GAN</span></a></li></ol></li><li class="toc-item toc-level-3"><a class="toc-link" href="#6-4-%E5%BA%94%E7%94%A8%E4%BB%8B%E7%BB%8D"><span class="toc-text">6.4 应用介绍</span></a><ol class="toc-child"><li class="toc-item toc-level-4"><a class="toc-link" href="#6-4-1-Conditional-GAN"><span class="toc-text">6.4.1 Conditional GAN</span></a></li><li class="toc-item toc-level-4"><a 
class="toc-link" href="#6-4-2-Cycle-GAN"><span class="toc-text">6.4.2 Cycle GAN</span></a></li></ol></li></ol></li><li class="toc-item toc-level-2"><a class="toc-link" href="#%E7%AC%AC%E4%B8%83%E7%AB%A0-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E5%AE%9E%E6%88%98"><span class="toc-text">第七章 深度学习实战</span></a><ol class="toc-child"><li class="toc-item toc-level-3"><a class="toc-link" href="#7-1-%E5%AE%9E%E4%BE%8B%E4%B8%80%EF%BC%8C%E7%8C%AB%E7%8B%97%E5%A4%A7%E6%88%98%EF%BC%9A%E8%BF%90%E7%94%A8%E9%A2%84%E8%AE%AD%E7%BB%83%E5%8D%B7%E7%A7%AF%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C%E8%BF%9B%E8%A1%8C%E7%89%B9%E5%BE%81%E6%8F%90%E5%8F%96%E4%B8%8E%E9%A2%84%E8%AE%AD"><span class="toc-text">7.1 实例一，猫狗大战：运用预训练卷积神经网络进行特征提取与预训</span></a><ol class="toc-child"><li class="toc-item toc-level-4"><a class="toc-link" href="#7-1-1-%E8%83%8C%E6%99%AF%E4%BB%8B%E7%BB%8D"><span class="toc-text">7.1.1 背景介绍</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#7-1-2-%E5%8E%9F%E7%90%86%E5%88%86%E6%9E%90"><span class="toc-text">7.1.2 原理分析</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#7-1-3-%E4%BB%A3%E7%A0%81%E5%AE%9E%E7%8E%B0"><span class="toc-text">7.1.3 代码实现</span></a></li></ol></li><li class="toc-item toc-level-3"><a class="toc-link" href="#7-2-%E5%AE%9E%E4%BE%8B%E4%BA%8C%EF%BC%8CDeep-Dream%EF%BC%9A%E6%8E%A2%E7%B4%A2%E5%8D%B7%E7%A7%AF%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C%E7%9C%BC%E4%B8%AD%E7%9A%84%E4%B8%96%E7%95%8C"><span class="toc-text">7.2 实例二，Deep Dream：探索卷积神经网络眼中的世界</span></a><ol class="toc-child"><li class="toc-item toc-level-4"><a class="toc-link" href="#7-2-1-%E5%8E%9F%E7%90%86%E4%BB%8B%E7%BB%8D"><span class="toc-text">7.2.1 原理介绍</span></a></li><li class="toc-item toc-level-4"><a class="toc-link" href="#7-2-2-%E4%BB%A3%E7%A0%81%E5%AE%9E%E7%8E%B0"><span class="toc-text">7.2.2 代码实现</span></a></li></ol></li></ol></li></ol></li></ol></div></div><div class="card-widget card-recent-post"><div class="item-headline"><i class="fas 
fa-history"></i><span>最新文章</span></div><div class="aside-list"><div class="aside-list-item"><a class="thumbnail" href="/posts/56b0563d.html" title="Hexo博客 | 如何让Butterfly主题导航栏居中"><img src="" data-lazy-src="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN2/post/20220315095300.png" onerror='this.onerror=null,this.src="/img/404.jpg"' alt="Hexo博客 | 如何让Butterfly主题导航栏居中"></a><div class="content"><a class="title" href="/posts/56b0563d.html" title="Hexo博客 | 如何让Butterfly主题导航栏居中">Hexo博客 | 如何让Butterfly主题导航栏居中</a><time datetime="2022-03-15T01:25:18.000Z" title="发表于 2022-03-15 09:25:18">2022-03-15</time></div></div><div class="aside-list-item"><a class="thumbnail" href="/posts/bb608df3.html" title="目标检测 | RetinaNet，经典单阶段Anchor-Based目标检测模型"><img src="" data-lazy-src="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN2/post/20220314113659.png" onerror='this.onerror=null,this.src="/img/404.jpg"' alt="目标检测 | RetinaNet，经典单阶段Anchor-Based目标检测模型"></a><div class="content"><a class="title" href="/posts/bb608df3.html" title="目标检测 | RetinaNet，经典单阶段Anchor-Based目标检测模型">目标检测 | RetinaNet，经典单阶段Anchor-Based目标检测模型</a><time datetime="2022-03-14T03:26:21.000Z" title="发表于 2022-03-14 11:26:21">2022-03-14</time></div></div><div class="aside-list-item"><a class="thumbnail" href="/posts/fc798de3.html" title="目标检测 | Faster R-CNN，经典两阶段检测模型"><img src="" data-lazy-src="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN2/post/20220312220823.png" onerror='this.onerror=null,this.src="/img/404.jpg"' alt="目标检测 | Faster R-CNN，经典两阶段检测模型"></a><div class="content"><a class="title" href="/posts/fc798de3.html" title="目标检测 | Faster R-CNN，经典两阶段检测模型">目标检测 | Faster R-CNN，经典两阶段检测模型</a><time datetime="2022-03-12T13:59:01.000Z" title="发表于 2022-03-12 21:59:01">2022-03-12</time></div></div><div class="aside-list-item"><a class="thumbnail" href="/posts/d150f284.html" title="深度学习 | 小样本学习基础概念"><img src="" data-lazy-src="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN2/post/202201271037441.png" onerror='this.onerror=null,this.src="/img/404.jpg"' 
alt="深度学习 | 小样本学习基础概念"></a><div class="content"><a class="title" href="/posts/d150f284.html" title="深度学习 | 小样本学习基础概念">深度学习 | 小样本学习基础概念</a><time datetime="2022-01-27T02:24:38.000Z" title="发表于 2022-01-27 10:24:38">2022-01-27</time></div></div><div class="aside-list-item"><a class="thumbnail" href="/posts/e05a9ab6.html" title="Jupyter Lab | 安装、配置、插件推荐、多用户使用教程"><img src="" data-lazy-src="https://cdn.jsdelivr.net/gh/Justlovesmile/CDN2/post/image-20211125175041150.png" onerror='this.onerror=null,this.src="/img/404.jpg"' alt="Jupyter Lab | 安装、配置、插件推荐、多用户使用教程"></a><div class="content"><a class="title" href="/posts/e05a9ab6.html" title="Jupyter Lab | 安装、配置、插件推荐、多用户使用教程">Jupyter Lab | 安装、配置、插件推荐、多用户使用教程</a><time datetime="2021-11-25T09:38:43.000Z" title="发表于 2021-11-25 17:38:43">2021-11-25</time></div></div></div></div></div></div></main><footer id="footer"><div id="footer-wrap"><div id="footer_deal"><a class="social-icon" href="mailto:865717150@qq.com" target="_blank" title="Email" rel="external nofollow noopener noreferrer"><i class="fas fa-envelope"></i></a><a class="social-icon" href="https://blog.csdn.net/qq_43701912" target="_blank" title="CSDN" rel="external nofollow noopener noreferrer"><i class="iconfont icon-csdn1"></i></a><a class="social-icon" href="https://github.com/Justlovesmile" target="_blank" title="Github" rel="external nofollow noopener noreferrer"><i class="fab fa-github"></i></a><a class="social-icon" href="https://weibo.com/u/5252319712" target="_blank" title="微博" rel="external nofollow noopener noreferrer"><i class="fa fa-weibo"></i></a><a class="social-icon" href="https://space.bilibili.com/168738824" target="_blank" title="Bilibili" rel="external nofollow noopener noreferrer"><i class="fas iconfont icon-bilibili"></i></a></div><div id="mj-footer"><div class="footer-group"><h3 class="footer-title">关于</h3><div class="footer-links"><a class="footer-item" target="_blank" rel="external nofollow noopener noreferrer" 
href="https://www.justlovesmile.top/">个人主页</a><a class="footer-item" href="/donate/">赞赏博主</a><a class="footer-item" href="/update/">博客日志</a><a class="footer-item" href="/charts/">博客统计</a></div></div><div class="footer-group"><h3 class="footer-title">归档</h3><div class="footer-links"><a class="footer-item" href="/archives/">文章归档</a><a class="footer-item" href="/tags/">全部标签</a><a class="footer-item" href="/categories/">全部分类</a><a class="footer-item" href="/random/">随机文章</a></div></div><div class="footer-group"><h3 class="footer-title">导航</h3><div class="footer-links"><a class="footer-item" href="/guestbook/">博客留言</a><a class="footer-item" href="/friends/">友情链接</a><a class="footer-item" href="/fcircle/">友链订阅</a><a class="footer-item" href="/atom.xml">RSS订阅</a></div></div><div class="footer-group"><h3 class="footer-title">协议</h3><div class="footer-links"><a class="footer-item" href="/privacy/">隐私协议</a><a class="footer-item" href="/cookies/">Cookies</a><a class="footer-item" href="/cc/">版权协议</a></div></div></div><div id="footer-banner"><div class="footer-banner-links"><div class="footer-banner-left"><div id="footer-banner-tips">©2019 - 2022 By Justlovesmile</div></div><div class="footer-banner-right"><a class="footer-banner-link" target="_blank" rel="external nofollow noopener noreferrer" href="http://beian.miit.gov.cn/">蜀ICP备20004960号</a><a class="footer-banner-link" href="/update/">主题</a><a class="footer-banner-link" href="/about/">关于</a></div></div></div></div></footer></div><div id="rightside"><div id="rightside-config-hide"><button id="readmode" type="button" title="阅读模式"><i class="fas fa-book-open"></i></button><button id="darkmode" type="button" title="浅色和深色模式转换"><i class="fas fa-adjust"></i></button></div><div id="rightside-config-show"><button id="rightside_config" type="button" title="设置"><i class="fas fa-cog fa-spin"></i></button><button class="close" id="mobile-toc-button" type="button" title="目录"><i class="fas fa-list-ul"></i></button><a id="to_comment" 
href="#post-comment" title="直达评论"><i class="fas fa-comments"></i></a><button id="go-up" type="button" title="回到顶部"><i class="fas fa-arrow-up"></i></button></div></div><div id="local-search"><div class="search-dialog"><nav class="search-nav"><span class="search-dialog-title">本地搜索</span><span id="loading-status"></span><button class="search-close-button"><i class="fas fa-times"></i></button></nav><div class="is-center" id="loading-database"><i class="fas fa-spinner fa-pulse"></i> <span>数据库加载中</span></div><div class="search-wrap"><div id="local-search-input"><div class="local-search-box"><input class="local-search-box--input" placeholder="搜索文章" type="text"></div></div><hr><div id="local-search-results"></div></div></div><div id="search-mask"></div></div><div><script src="/js/utils.js"></script><script src="/js/main.js"></script><script defer="defer" src="https://cdn.jsdelivr.net/npm/@fancyapps/ui/dist/fancybox.umd.js"></script><script defer="defer" src="https://cdn.jsdelivr.net/npm/instant.page/instantpage.min.js" type="module"></script><script defer="defer" src="https://cdn.jsdelivr.net/npm/vanilla-lazyload/dist/lazyload.iife.min.js"></script><script defer="defer" src="https://cdn.jsdelivr.net/npm/node-snackbar/dist/snackbar.min.js"></script><script defer="defer" src="/js/search/local-search.js"></script><div class="js-pjax"><script>if(window.MathJax)MathJax.startup.document.state(0),MathJax.texReset(),MathJax.typeset();else{window.MathJax={tex:{inlineMath:[["$","$"],["\\(","\\)"]],tags:"ams"},chtml:{scale:1.2},options:{renderActions:{findScript:[10,t=>{for(const e of document.querySelectorAll('script[type^="math/tex"]')){const a=!!e.type.match(/; *mode=display/),n=new t.options.MathItem(e.textContent,t.inputJax[0],a),s=document.createTextNode("");e.parentNode.replaceChild(s,e),n.start={node:s,delim:"",n:0},n.end={node:s,delim:"",n:0},t.math.push(n)}},""],insertScript:[200,()=>{document.querySelectorAll("mjx-container:not([display])").forEach(t=>{const 
e=t.parentNode;"li"===e.nodeName.toLowerCase()?e.parentNode.classList.add("has-jax"):e.classList.add("has-jax")})},"",!1]}}};const t=document.createElement("script");t.src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js",t.id="MathJax-script",t.async=!0,document.head.appendChild(t)}</script><script>(()=>{const t=()=>{twikoo.init(Object.assign({el:"#twikoo-wrap",envId:"blog-comment-3gt33nkmf9f97e6e",region:"ap-shanghai",onCommentLoaded:function(){btf.loadLightbox(document.querySelectorAll("#twikoo .tk-content img:not(.vemoji)"))}},null))},o=()=>{"object"!=typeof twikoo?getScript("https://cdn.jsdelivr.net/npm/twikoo/dist/twikoo.all.min.js").then(t):setTimeout(t,0)};btf.loadComment(document.getElementById("twikoo-wrap"),o)})()</script></div><canvas id="universe"></canvas><script defer="defer">console.log("\n %c 欢迎来到Justlovesmile の Blog %c https://github.com/Justlovesmile %c https://blog.justlovesmile.top \n","color: #f9ed69; background: #252a34; padding:5px 0;","background: #3fc1c9; padding:5px 0;","background: #3fc1c9; padding:5px 0;")</script><script defer="defer" src="/js/rgbaster.min.js"></script><script defer="defer" src="/js/justlovesmile.js"></script><script>window.addEventListener("load",async()=>{navigator.serviceWorker.register("/js/sw-cdn.js?time="+(new Date).getTime()).then(async e=>{"true"!=window.localStorage.getItem("install")&&(window.localStorage.setItem("install","true"),setTimeout(()=>{window.location.search="?time="+(new Date).getTime()},1e3))}).catch(e=>{console.log("sw-cdn.js error")})})</script></div></body></html>